<a href="https://colab.research.google.com/github/pglez82/IFCB_semisupervised/blob/master/IFCB_FT_Baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Load the data
We are going to finetune a resnet18 and extract features with it

In [1]:
import os


if not os.path.isfile("IFCB_data.tar") and not os.path.isdir("data"):
  print("Data do not exist in local. Downloading...")
  !wget -O IFCB_data.tar https://unioviedo-my.sharepoint.com/:u:/g/personal/gonzalezgpablo_uniovi_es/Ec2z0uC4lghEg-9MjzoJ9QkBK5n74QjS-LszB9dlNrPfaw?download=1
else:
  print("Data already exists. Skipping download.")

if not os.path.isdir("data"):
  print("Extracting the tar file...")
  !tar -xf "IFCB_data.tar"
  print("Done. Removing the tar file.")
  !rm -f IFCB_data.tar #Remove the original file to save space

Data already exists. Skipping download.


# Download CSV with information about the images


In [2]:
import pandas as pd

if not os.path.isfile('IFCB.csv.zip'):
  print("CSV data do not exist. Downloading...")
  !wget -O IFCB.csv.zip "https://unioviedo-my.sharepoint.com/:u:/g/personal/gonzalezgpablo_uniovi_es/EfsVLhFsYJpPjO0KZlpWUq0BU6LaqJ989Re4XzatS9aG4Q?download=1"

data = pd.read_csv('IFCB.csv.zip',compression='infer', header=0,sep=',',quotechar='"')
print(data)

                        Sample  roi_number        OriginalClass  \
0        IFCB1_2006_158_000036           1                  mix   
1        IFCB1_2006_158_000036           2  Tontonia_gracillima   
2        IFCB1_2006_158_000036           3                  mix   
3        IFCB1_2006_158_000036           4                  mix   
4        IFCB1_2006_158_000036           5                  mix   
...                        ...         ...                  ...   
3457814  IFCB5_2014_353_205141        6850       Leptocylindrus   
3457815  IFCB5_2014_353_205141        6852                  mix   
3457816  IFCB5_2014_353_205141        6855                  mix   
3457817  IFCB5_2014_353_205141        6856                  mix   
3457818  IFCB5_2014_353_205141        6857                  mix   

              AutoClass FunctionalGroup  
0                   mix      Flagellate  
1           ciliate_mix         Ciliate  
2                   mix      Flagellate  
3                   mix    

# Create training set

Here we make a reestructuration of the images depending on which class we consider

In [3]:
import progressbar
from tqdm import tqdm
tqdm.pandas()

classcolumn = "AutoClass" #Autoclass means 51 classes
yearstraining = ['2006'] #Years to consider as training
yearsvalidation = ['2007']
trainingfolder = "training"
validationfolder = "validation"

classes = pd.unique(data[classcolumn])
print("Considering %i classes" % len(classes))

print("Computing image paths...")
#Compute data paths
data['year'] = data['Sample'].str[6:10].astype(str)
data['path']="data"+'/'+data['year']+'/'+data['OriginalClass'].astype(str)+'/'+data['Sample'].astype(str)+'_'+data['roi_number'].apply(lambda x: str(x).zfill(5))+'.png'
print('Done')

if not os.path.isdir(trainingfolder):
  print("Create folder structure for training set...")
  os.mkdir(trainingfolder)
  for folder in classes:
    os.mkdir(os.path.join(trainingfolder,folder))
  print("Done.\nMoving images to the respective folders...")
  data[data['year'].isin(yearstraining)].progress_apply(lambda row: os.rename(row['path'],os.path.join(trainingfolder,row[classcolumn],os.path.basename(row['path']))),axis=1)
  print("Done")
else:
  print("Training data already there... Doing nothing")

if not os.path.isdir(validationfolder):
  print("Create folder structure for the validation set...")
  os.mkdir(validationfolder)
  for folder in classes:
    os.mkdir(os.path.join(validationfolder,folder))
  print("Done.\nMoving images to the respective folders...")
  data[data['year'].isin(yearsvalidation)].progress_apply(lambda row: os.rename(row['path'],os.path.join(validationfolder,row[classcolumn],os.path.basename(row['path']))),axis=1)
  print("Done")  
else:
  print("Validation data already there... Doing nothing")


Considering 51 classes
Computing image paths...


  from pandas import Panel


Done
Training data already there... Doing nothing
Validation data already there... Doing nothing


# Configure the process

In [4]:
import torch

num_workers = 16 # @param
batch_size = 256 # @param 
train_dir = './training'
val_dir = './validation'
num_epochs_ft1 = 10 # @param
num_epochs_ft2 = 10 # @param
proportion = 1 #How many labelled examples do we take

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print("Using %s"%device)

Using cuda:1


# Prepare de DataLoaders for the CNN
In this step it is important to consider that we have to use images with the same size than the original network (so we can reuse the weights)

In [8]:
import torchvision
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import numpy as np

def create_balanced_splits(train_loader,proportions):
  """
  This function creates different balanced splits following the proportions
  """
  labels_vector = []
  for x,y in train_loader:
    labels_vector.extend(y.numpy())
  unique, globalcounts = np.unique(labels_vector, return_counts=True)
  #Find indexes for each class
  classindexes = []
  for c in unique:
    classindexes.append(np.where(labels_vector == c)[0])
  subsets = {}
  for p in proportions:
    subsets[p]=[]
    counts = np.rint(globalcounts*p)
    print("Building subset with %d elements"%sum(counts))
    for i in range(len(counts)):
      subsets[p].extend(classindexes[i][0:int(counts[i])])
  return subsets

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

#Define transofrmations
train_transform = T.Compose([
  T.Resize(size=256),
  T.RandomResizedCrop(size=224),
  T.RandomHorizontalFlip(),
  T.ToTensor(),            
  #T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

val_transform = T.Compose([
  T.Resize(size=256),
  T.CenterCrop(size=224),
  T.ToTensor(),
  #T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
  

#Define data loader
train_dset = ImageFolder(train_dir, transform=train_transform)
train_loader = DataLoader(train_dset,batch_size=batch_size,num_workers=num_workers,shuffle=False)
num_classes = len(train_dset.classes)
indexes=create_balanced_splits(train_loader,[proportion])[proportion] #We only pass one value, this function if prepared to receive multiple proportions
train_dset=torch.utils.data.Subset(train_dset,indexes)
train_loader = DataLoader(train_dset,batch_size=batch_size,num_workers=num_workers,shuffle=True)
print("Working with %f of the current data"%proportion)

val_dset = ImageFolder(val_dir, transform=val_transform)
val_loader = DataLoader(val_dset,batch_size=batch_size,num_workers=num_workers)

Building subset with 131002 elements
Working with 1.000000 of the current data


# Load the CNN
In this step we download a pretrained CNN with the weights from ImageNet. We change the last layer to match the number of classes that we have in our problem

In [9]:
import torch.nn as nn

model = torchvision.models.resnet18(pretrained=True)
print("Adjusting the CNN for %s classes" % num_classes)
model.fc = nn.Linear(model.fc.in_features, num_classes)
#Define loss function
loss_fn = nn.CrossEntropyLoss()
model = model.to(device)

Adjusting the CNN for 51 classes


# Perform finetuning
First we only update the last layer for a few epochs, then we update all the weights with a small learning rate

In [11]:
import time

def run_epoch(model, loss_fn, loader, optimizer, device):
  """
  Train the model for one epoch.
  """
  loss_epoch = 0
  start_time = time.time()
  # Set the model to training mode
  model.train()
  for step, (x, y) in enumerate(loader):
    
    x = x.to(device)
    y = y.to(device)

    # Run the model forward to compute scores and loss.
    scores = model(x)
    loss = loss_fn(scores, y)
    loss_epoch = loss_epoch + loss.item()
    # Run the model backward and take a step using the optimizer.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 50== 0:
      spent = time.time()-start_time
      print(f"Step [{step}/{len(loader)}]\t Loss: {loss.item()} \t Time: {spent} secs [{(batch_size*50)/spent} ej/sec]]")
      start_time = time.time()
  return loss_epoch

def check_accuracy(model, loader, device):
  """
  Check the accuracy of the model.
  """
  # Set the model to eval mode
  accuracy_epoch = 0
  model.eval()
  num_correct, num_examples = 0, 0
  for x, y in loader:
    x = x.to(device)
    y = y.to(device)
    # Run the model forward, and compare the argmax score with the ground-truth
    # category.
    output = model(x)
    #_, preds = scores.data.cpu().max(1)
    #num_correct += (preds == y).sum()
    #num_examples += x.size(0)
    predicted = output.argmax(1)
    acc = (predicted == y).sum().item() / y.size(0)
    accuracy_epoch += acc

  # Return the fraction of datapoints that were correctly classified.
  #acc = float(num_correct) / num_examples
  accuracy_epoch = float(accuracy_epoch) / len(loader)
  return accuracy_epoch

for param in model.parameters():
  param.requires_grad = False
for param in model.fc.parameters():
  param.requires_grad = True

optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

#First phase of finetuning
for epoch in range(num_epochs_ft1):
  # Run an epoch over the training data.
  print('Starting epoch %d / %d' % (epoch + 1,num_epochs_ft1))
  loss_epoch = run_epoch(model, loss_fn, train_loader, optimizer, device)

  # Check accuracy on the train and val sets.
  #train_acc = check_accuracy(model, train_loader, device)
  print(f"Epoch [{epoch+1}/{num_epochs_ft1}]\t Loss: {loss_epoch / len(train_loader)}")

#Allow updating all the weights in the second phase
for param in model.parameters():
  param.requires_grad = True

#Lower learning rate this time
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# Train the entire model for a few more epochs, checking accuracy on the
# train sets after each epoch.
for epoch in range(num_epochs_ft2):
  print('Starting epoch %d / %d' % (epoch + 1, num_epochs_ft2))
  loss_epoch = run_epoch(model, loss_fn, train_loader, optimizer, device)

  print(f"Epoch [{epoch+1}/{num_epochs_ft2}]\t Loss: {loss_epoch / len(train_loader)}")
    
print("Performing final validation in test examples...")
val_acc = check_accuracy(model, val_loader, device)
print('Val accuracy: ', val_acc)


Starting epoch 1 / 10
Step [0/512]	 Loss: 0.7009488344192505 	 Time: 2.3488380908966064 secs [5449.502905121034 ej/sec]]
Step [50/512]	 Loss: 0.6747325658798218 	 Time: 7.595005750656128 secs [1685.3180129447849 ej/sec]]
Step [100/512]	 Loss: 0.6512468457221985 	 Time: 7.529764175415039 secs [1699.9204359935304 ej/sec]]
Step [150/512]	 Loss: 0.6653121709823608 	 Time: 7.5224809646606445 secs [1701.5662864595147 ej/sec]]
Step [200/512]	 Loss: 0.6299343705177307 	 Time: 7.591899156570435 secs [1686.0076426228866 ej/sec]]
Step [250/512]	 Loss: 0.6107427477836609 	 Time: 7.587847948074341 secs [1686.9078146522968 ej/sec]]
Step [300/512]	 Loss: 0.5977334976196289 	 Time: 7.549274206161499 secs [1695.5272322143248 ej/sec]]
Step [350/512]	 Loss: 0.5632631182670593 	 Time: 7.542078256607056 secs [1697.144946591726 ej/sec]]
Step [400/512]	 Loss: 0.573750376701355 	 Time: 7.554845809936523 secs [1694.2768022035314 ej/sec]]
Step [450/512]	 Loss: 0.6550976037979126 	 Time: 7.556173324584961 secs [

Step [50/512]	 Loss: 0.5511623024940491 	 Time: 7.610683441162109 secs [1681.8463281197137 ej/sec]]
Step [100/512]	 Loss: 0.5964354276657104 	 Time: 7.53943395614624 secs [1697.74018506592 ej/sec]]
Step [150/512]	 Loss: 0.6695919632911682 	 Time: 7.532733201980591 secs [1699.2504124046873 ej/sec]]
Step [200/512]	 Loss: 0.4962494969367981 	 Time: 7.540668725967407 secs [1697.4621834163472 ej/sec]]
Step [250/512]	 Loss: 0.5630505084991455 	 Time: 7.551358938217163 secs [1695.0591416360369 ej/sec]]
Step [300/512]	 Loss: 0.6170036196708679 	 Time: 7.5548319816589355 secs [1694.279903388308 ej/sec]]
Step [350/512]	 Loss: 0.6068161129951477 	 Time: 7.551878213882446 secs [1694.9425874572566 ej/sec]]
Step [400/512]	 Loss: 0.7377872467041016 	 Time: 7.553398132324219 secs [1694.6015257984786 ej/sec]]
Step [450/512]	 Loss: 0.5438460111618042 	 Time: 7.560569763183594 secs [1692.9940997740619 ej/sec]]
Step [500/512]	 Loss: 0.6156167387962341 	 Time: 7.504220008850098 secs [1705.706920226796 ej/s

Step [100/512]	 Loss: 0.5170642137527466 	 Time: 18.543452501296997 secs [690.2705954624535 ej/sec]]
Step [150/512]	 Loss: 0.3893302083015442 	 Time: 18.580814123153687 secs [688.8826245804713 ej/sec]]
Step [200/512]	 Loss: 0.5456705093383789 	 Time: 18.522735118865967 secs [691.0426520629134 ej/sec]]
Step [250/512]	 Loss: 0.37861368060112 	 Time: 18.45746088027954 secs [693.4865029932623 ej/sec]]
Step [300/512]	 Loss: 0.3915737271308899 	 Time: 18.597198724746704 secs [688.2757015962541 ej/sec]]
Step [350/512]	 Loss: 0.4708123505115509 	 Time: 18.5426926612854 secs [690.2988812797747 ej/sec]]
Step [400/512]	 Loss: 0.4391838610172272 	 Time: 18.686509370803833 secs [684.9861440681356 ej/sec]]
Step [450/512]	 Loss: 0.476293683052063 	 Time: 18.476900339126587 secs [692.7568891463244 ej/sec]]
Step [500/512]	 Loss: 0.3648203909397125 	 Time: 18.53321099281311 secs [690.6520410825539 ej/sec]]
Epoch [5/10]	 Loss: 0.4513296154909767
Starting epoch 6 / 10
Step [0/512]	 Loss: 0.524068474769592