In [None]:
!pip install torchvision==0.10.0

In [None]:
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import torch
import torch.nn as nn
import torchvision

In [None]:
# Show the torch and torchvision versions in use, on a single line.
print(torch.__version__, torchvision.__version__)

In [None]:
#Clone Recursion Pharma's utilities
!git clone https://github.com/recursionpharma/rxrx1-utils.git

sys.path.append('/kaggle/working/rxrx1-utils')
import rxrx.io as rio

In [None]:
#Loading in and resizing the images to 224x224 for resnet

def load_and_resize(dataset, experiment, plate, well, site):
    """Load one site image via ``rio.load_site`` and return it as a float32 tensor.

    The loaded array is resized to 224x224 with ``cv2.resize``, cast to
    ``np.float32``, and its last axis is moved to the front before returning.

    Args:
        dataset, experiment, plate, well, site: forwarded unchanged to
            ``rio.load_site`` to select the image.

    Returns:
        torch.Tensor built from the resized array, permuted with ``(2, 0, 1)``.
    """
    site_pixels = rio.load_site(
        dataset, experiment, plate, well, site,
        base_path='../input/recursion-cellular-image-classification/',
    )
    shrunk = cv2.resize(site_pixels, (224, 224))
    as_float = shrunk.astype(np.float32)
    # Last axis first (assumes load_site returns height x width x channels — TODO confirm).
    return torch.from_numpy(as_float).permute(2, 0, 1)

In [None]:
#Had to make my own dataset class, the torch example using ImageFolder
#and torch transforms wouldn't work on my dataset

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that loads one resized site image per sample ID.

    Args:
        list_IDs: sequence of integer row positions. Each ID is used with
            ``.iloc`` on the metadata frame and as an index into ``labels``.
        labels: container indexed by sample ID that yields the target value.
        df: optional metadata frame providing ``experiment``, ``plate``,
            ``well`` and ``site`` columns. When omitted, the module-level
            ``train_df`` is looked up at access time, which keeps existing
            ``Dataset(list_IDs, labels)`` calls working unchanged.
        dataset: dataset name forwarded to ``load_and_resize``; defaults to
            ``'train'``.
    """

    def __init__(self, list_IDs, labels, df=None, dataset='train'):
        """Store the sample IDs, their labels and the optional metadata source."""
        self.labels = labels
        self.list_IDs = list_IDs
        self.df = df
        self.dataset = dataset

    def __len__(self):
        """Return the total number of samples."""
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair for the sample at ``index``."""
        # Select sample
        ID = self.list_IDs[index]

        # Fall back to the notebook-level frame only when none was injected,
        # and resolve it lazily so the class can be defined before train_df.
        meta = train_df if self.df is None else self.df

        # Load data and get label
        exp = meta.experiment.iloc[ID]
        plate = meta.plate.iloc[ID]
        well = meta.well.iloc[ID]
        site = meta.site.iloc[ID]

        X = load_and_resize(self.dataset, exp, plate, well, site)
        y = self.labels[ID]

        return X, y

In [None]:
#Getting a subset of the data to test with

# Read the training metadata and keep only the first four rows as a test subset.
# .assign returns a new frame, so adding the column never writes into a slice of
# the full table (the pattern that triggers pandas' SettingWithCopyWarning).
train_df = (
    pd.read_csv('../input/recursion-cellular-image-classification/train.csv')
    .iloc[0:4]
    .assign(site=1)
)
# Row positions within train_df, and the labels handed to Dataset for those rows.
indexes = [0,1,2,3]
sirnas = [250, 60, 43, 20]

In [None]:
#Importing a pre-trained ResNet model for transfer learning and freezing all but the last layer
# Resolve the device in this cell so it does not depend on a later cell having run.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

resnet = torchvision.models.resnet18(pretrained=True)
for param in resnet.parameters():
    param.requires_grad = False

#Redefining the final fully connected layer
# (created after the freeze loop, so its parameters keep requires_grad=True)
infeat = resnet.fc.in_features
nclasses = 1139
resnet.fc = nn.Linear(infeat, nclasses)

#Adding a convolutional layer to match the channels to the resnet model
first_conv_layer = [torch.nn.Conv2d(6, 3, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)]
# Chain the intact ResNet module after the adapter. Unpacking resnet.children()
# into nn.Sequential would lose the flatten that ResNet.forward applies between
# avgpool and fc, and the forward pass would fail with a shape mismatch at fc.
first_conv_layer.append(resnet)
resnet = torch.nn.Sequential(*first_conv_layer)

resnet = resnet.to(device)

criterion = nn.CrossEntropyLoss()

# Only the parameters left trainable (the 6->3 adapter conv and the new fc) are optimized.
optimizer_conv = torch.optim.SGD(filter(lambda p: p.requires_grad, resnet.parameters()), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

In [None]:
# Batch the four-sample subset: one shuffled batch of 4, loaded by 4 worker processes.
dataloader = torch.utils.data.DataLoader(
    Dataset(indexes, sirnas),
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Run every batch through the network once to check the pipeline end to end.
for inputs, labels in dataloader:
    inputs = inputs.to(device)
    # The DataLoader's default collation already yields labels as a tensor;
    # torch.tensor(<tensor>) copy-constructs it and emits a UserWarning, while
    # as_tensor reuses an existing tensor and still converts plain sequences.
    labels = torch.as_tensor(labels).to(device)
    outputs = resnet(inputs)