In [None]:
!pip3 install --no-deps cell-data-loader
!pip3 install tkinter

In [None]:
from cell_data_loader import CellDataloader
import torch,os,glob
from torchvision.models import resnet50 #, ResNet50_Weights
import tkinter as tk
import tkinter.filedialog

In [2]:
### CHANGE THESE LINES
# Default locations. They are used as-is when `use_tkinter` is False, and as a
# fallback when a folder dialog is cancelled.
wd = '/Users/mleming/Desktop/CellDataLoader_playground'
train_folder_name = os.path.join(wd,'dropbox_downloads','train')
test_folder_name = os.path.join(wd,'dropbox_downloads','test')
use_tkinter = True

# `root` stays defined in both modes, but the Tk window is only created when
# the dialogs are actually requested, so the cell also runs without a display.
root = None
if use_tkinter:
    root = tk.Tk()
    root.withdraw()  # hide the empty main window; only the dialogs are wanted
    try:
        # askdirectory returns an empty string when the dialog is cancelled;
        # `or` keeps the default above instead of storing an empty path.
        print("Select your working directory to save the model")
        wd = tk.filedialog.askdirectory(
            parent=root,
            title="Please select the directory to save the model") or wd
        print("Select the directory with the train cell data")
        train_folder_name = tk.filedialog.askdirectory(
            parent=root,
            title="Please select the directory with training cell data"
        ) or train_folder_name
        print("Select the directory with the test cell data")
        test_folder_name = tk.filedialog.askdirectory(
            parent=root,
            title="Please select the directory with test cell data"
        ) or test_folder_name
    finally:
        # Always release the Tk window, even if a dialog raised.
        root.destroy()
print(f"Working directory to save model: {wd}")
print(f"Location of train data: {train_folder_name}")
print(f"Location of test data: {test_folder_name}")

## These strings tell files of one label apart from files of another: each
## string is matched against the file paths. In this case, if some folders have
## the word "blurry" and others have the word "clear", they will have labels
## 1 and 2. Note that CellDataLoader pushes a warning if any given path matches
## both strings (e.g., "/this/path/is/Blurry/and/Clear/img.png").

label_regex_strings = ["Blurry","Clear"]

Select your working directory to save the model
Select the directory with the train cell data
Select the directory with the test cell data
Working directory to save model: /Users/mleming/Desktop/CellDataLoader_playground
Location of train data: /Users/mleming/Desktop/CellDataLoader_playground/dropbox_downloads/train
Location of test data: /Users/mleming/Desktop/CellDataLoader_playground/dropbox_downloads/test


In [3]:
# Checkpoints live in a `checkpoints` sub-folder of the working directory.
model_folder = os.path.join(wd, 'checkpoints')

# Make sure both the working directory and the checkpoint folder exist.
for _directory in (wd, model_folder):
    os.makedirs(_directory, exist_ok=True)

# Single file that training writes to and that model loading reads from.
model_file = os.path.join(model_folder, 'torch_model.pt')

In [4]:
def train_torch(model,
                dataloader,
                model_file = None,
                epochs = 50,
                gpu_ids = None,
                verbose = True,
                loss_break=None):
    """Train `model` in place on `dataloader` using cross-entropy loss and Adam.

    Parameters
    ----------
    model : torch.nn.Module
        Network to optimise. It is switched to training mode.
    dataloader : iterable
        Yields ``(image, y)`` batches and is iterated once per epoch. `y` is a
        2-D tensor of per-class targets whose rows each sum to 1 (asserted
        below); it is zero-padded on the right so its width matches the
        number of model outputs.
    model_file : str or None
        If given, the whole model object is written to this path with
        ``torch.save`` after every epoch.
    epochs : int
        Maximum number of passes over `dataloader`.
    gpu_ids : device specifier or None
        If given, the model and every batch are moved there with ``.to``.
    verbose : bool
        Print progress messages and the mean loss of each epoch.
    loss_break : float or None
        If given, training stops as soon as an epoch's mean loss falls below
        this value.

    Returns
    -------
    None
    """
    import warnings

    if gpu_ids is not None:
        model.to(gpu_ids)

    # Train

    model.train()
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    if verbose: print("Beginning training")
    for epoch in range(epochs):
        loss_total,count = 0.0,0
        for image,y in dataloader:
            if gpu_ids is not None:
                # Keep the batch on the same device as the model; `.to` is a
                # no-op for tensors that are already there.
                image, y = image.to(gpu_ids), y.to(gpu_ids)
            y_pred = model(image)
            # The model may have more outputs than there are label columns;
            # pad the targets with zeros so both have the same width.
            y = torch.nn.functional.pad(y,(0,y_pred.size()[1]-y.size()[1]))
            assert torch.all(y.sum(1) == 1), \
                "every row of the label tensor must sum to 1"
            loss = loss_fn(y_pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_total += loss.item()
            count += 1
        if model_file is not None:
            torch.save(model,model_file)
        if count == 0:
            # Without this guard the mean below would divide by zero.
            warnings.warn("train_torch: dataloader yielded no batches; "
                          "nothing was trained this epoch")
            mean_loss = float("nan")
        else:
            mean_loss = loss_total / count
        if verbose:
            print(
                "Epoch {epoch:d}/{epochs:d}: loss: {loss:.5f}".format(
                    epoch=epoch,epochs=epochs,loss=mean_loss)
            )
        # NaN compares False, so an empty epoch never triggers the early exit.
        if loss_break is not None and mean_loss < loss_break:
            if verbose: print("Loss break - exiting")
            return

def test_torch(model,
               dataloader,
               gpu_ids = None,
               verbose=True):
    
    model.eval()
    total_images = 0
    sum_accuracy = 0
    for image,y in dataloader:
        total_images += image.size()[0]
        y_pred = model(image)
        y_pred = y_pred[:,:y.size()[1]]
        sum_accuracy += torch.sum(torch.argmax(y_pred,axis=1) == \
            torch.argmax(y,axis=1))
    accuracy = sum_accuracy / total_images
    if verbose: print("Final accuracy: %.4f" % accuracy)


In [8]:
# Get a pretrained model from torchvision, or resume from a saved checkpoint
if os.path.isfile(model_file):
    print("Loading %s" % model_file)
    # The checkpoint is a whole pickled model object, not a state dict, so it
    # must be loaded with `weights_only=False`; PyTorch releases that default
    # to `weights_only=True` refuse to unpickle it otherwise.
    # SECURITY: unpickling executes arbitrary code - only load checkpoint
    # files that you created yourself.
    try:
        model = torch.load(model_file, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the `weights_only` keyword.
        model = torch.load(model_file)
else:
    # Download a pretrained resnet and edit its outputs to be compatible with cross entropy loss
    # NOTE(review): `pretrained=True` appears to be the legacy torchvision
    # flag - confirm it is still accepted before upgrading torchvision.
    model = resnet50(pretrained=True)
    #model.fc = torch.nn.Sequential(
    #    torch.nn.Dropout(0.5),
    #    torch.nn.Linear(2048, 1000)
    #)

# The train and test loaders share every setting except the folder they read.
loader_settings = dict(
    label_regex = label_regex_strings,
    dtype = "torch",
    verbose = True,
    batch_size = 64,
    gpu_ids = None,
    n_channels = 3,
)

print("Preparing Train Data")
dataloader_train = CellDataloader(train_folder_name, **loader_settings)

print("Preparing Test Data")
dataloader_test = CellDataloader(test_folder_name, **loader_settings)

Loading /Users/mleming/Desktop/CellDataLoader_playground/checkpoints/torch_model.pt
Preparing Train Data
Detected label format: Regex
427 image paths read
3 Channels Detected
Preparing Test Data
Detected label format: Regex
50 image paths read
3 Channels Detected


In [None]:
# Train the model on the training loader. Passing `model_file` makes
# train_torch save the model to that path after every epoch.
train_torch(model,dataloader_train,model_file)

In [None]:
# Evaluate the trained model on the test loader; test_torch prints the
# final accuracy.
test_torch(model,dataloader_test)