Check whether a GPU is attached to this runtime.

In [None]:
# List the attached NVIDIA GPU(s) and their current memory usage.
!nvidia-smi

Google Colab now ships fastai v2. While the functionality is essentially identical, some library/module names changed between v1 and v2, so this notebook pins fastai v1 (1.0.61).

In [None]:
!pip install fastai==1.0.61

In [None]:
import fastai.callback

In [None]:
import pandas as pd
!unzip -q sample_data/singletrialfish_Aug2022.zip
fishdf=pd.read_csv('sample_data/trialfish_Aug2022_tmp.csv', header=0, index_col=0)

In [None]:
# Keep the first KEEP_ROWS rows of each consecutive CHUNK_ROWS-row chunk of `fishdf`
# (rows 0-99, 200-299, 400-499, ...) for N_CHUNKS chunks, and stack them into `newdf`.
N_CHUNKS = 400    # number of chunks sampled
CHUNK_ROWS = 200  # stride between chunk starts
KEEP_ROWS = 100   # rows kept from the start of each chunk

# .iloc makes the slicing explicitly positional, whatever fishdf's index labels are.
testdf2 = [fishdf.iloc[start:start + KEEP_ROWS]
           for start in range(0, N_CHUNKS * CHUNK_ROWS, CHUNK_ROWS)]
newdf = pd.concat(testdf2)

Data loader for our dataset (#neurons × time × time → future movie frame).

In [None]:
import torch
from fastai import *
from fastai.vision import *
import PIL

def open_grammian_to_singleImage(fname):
  """Load a tensor saved with ``torch.save`` and wrap it as a fastai ``Image``.

  Args:
    fname: path to a file written by ``torch.save`` (presumably a
      neurons x time x time tensor -- TODO confirm shape/channel layout).

  Returns:
    A fastai ``Image`` wrapping the tensor converted to a CPU float32 tensor.
  """
  # map_location='cpu': tensors saved from a GPU session also load in CPU-only
  # processes instead of raising; the data ends up on the CPU anyway because of
  # the FloatTensor cast below.
  # NOTE(review): torch.load unpickles the file -- only open trusted data.
  mat = torch.load(fname, map_location='cpu')
  mat = mat.type(torch.FloatTensor)
  return Image(mat)

class grammiansingleImageImageImageList(ImageImageList):
    """fastai ``ImageImageList`` whose items are loaded with ``open_grammian_to_singleImage``.

    Only ``open`` is overridden, so the labels are still built with the parent's
    ``_label_cls``.
    NOTE(review): in fastai v1 that is presumably the plain ``ImageList`` (PIL-based
    loader), i.e. target files are expected to be ordinary image files -- confirm.
    """

    def open(self, fn):
        # Use the tensor loader instead of the parent's image-file opener.
        loaded_image = open_grammian_to_singleImage(fn)
        return loaded_image


class Multi_to_MultiGrammianList(grammiansingleImageImageImageList):
    """``ItemList`` for tensor-to-tensor tasks.

    Both inputs and labels are loaded with the tensor loader, because
    ``grammiansingleImageImageImageList`` is also used as the label class.
    """
    _label_cls = grammiansingleImageImageImageList
    _square_show = False
    _square_show_res = False

In [None]:
# Inputs come from column 'trn1' and targets from column 'tst' of `newdf`
# (presumably file paths relative to path='.' -- confirm).
x_cols = ['trn1']
y_cols = ['tst']
il = grammiansingleImageImageImageList.from_df(path='.', df=newdf, cols=x_cols)
ils = il.split_by_rand_pct(0.1, seed=42)  # 10% validation split, fixed seed
# NOTE(review): with p_affine=0. the rotate (max_rotate=10) and zoom (max_zoom=1.01)
# transforms are created but never applied; flips, lighting and warp are off as well.
tfms = get_transforms(flip_vert=False, do_flip=False,
                      max_rotate=10, max_zoom=1.01, max_lighting=None, max_warp=None,
                      p_affine=0., p_lighting=0.)
# tfm_y=True applies the same transforms (including the resize to 128) to the targets.
ils2 = ils.label_from_df(cols=y_cols).transform(tfms, size=128, tfm_y=True)
ils2

A typical GPU available in Google Colab has 16 GB of VRAM. Because a pretrained VGG19 is used in our loss, at most 32 images fit in one batch (≈500 MB × 32 ≈ 16 GB).

In [None]:
bs = 32  # batch size (see the memory estimate above)
# ils2 is the labelled, transformed LabelLists. In fastai v1, label_from_df/transform on an
# ItemLists mutate and return that same object, so ils2 is ils; naming ils2 here makes the
# dependency on the labelling step explicit. num_workers=0 loads batches in the main process.
data = ils2.databunch(bs=bs, num_workers=0)

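Define the loss function: a perceptual loss that combines a pixel-wise L1 term with L1 terms on pretrained VGG19-BN activations and on their Gram matrices.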

In [None]:
def gram_matrix(x):
    """Batched Gram matrix of a 4-D feature map.

    Args:
        x: tensor of shape (n, c, h, w).

    Returns:
        Tensor of shape (n, c, c): channel-by-channel inner products over the
        h*w spatial positions, divided by c*h*w.
    """
    n, c, h, w = x.size()
    flat = x.view(n, c, h * w)                    # (n, c, h*w)
    gram = torch.bmm(flat, flat.transpose(1, 2))  # (n, c, c)
    return gram / (c * h * w)
from torchvision.models import vgg19_bn
from fastai.callbacks import *
from fastai.utils.mem import *
base_loss = F.l1_loss  # every term in FeatureLoss (pixel, feature, Gram) is an L1 loss
# Frozen, pretrained VGG19-BN feature extractor, used only inside the loss.
# .cuda() requires a GPU runtime.
# NOTE(review): positional `pretrained` is deprecated in newer torchvision (use weights=...).
vgg_m = vgg19_bn(True).features.cuda().eval()
requires_grad(vgg_m, False)
# Index of the layer immediately before each MaxPool2d, i.e. the output of each block
# just before downsampling; FeatureLoss hooks a subset of these.
blocks = [i - 1 for i, o in enumerate(children(vgg_m)) if isinstance(o, nn.MaxPool2d)]
print(blocks)
# This feature loss relies on a second, pretrained network (m_feat) to compare prediction
# and target. Without such a network a different loss function is needed (VGG is very
# memory-hungry).
class FeatureLoss(nn.Module):
    """Perceptual loss: pixel L1 + L1 on hooked feature activations + L1 on their Gram matrices.

    Both the prediction (``input``) and the ``target`` are passed through the frozen
    feature network ``m_feat``. The activations of the layers in ``layer_ids`` are captured
    with forward hooks and compared.

    Args:
        m_feat: feature network, indexable by layer id (the VGG19-BN features here).
        layer_ids: indices into ``m_feat`` of the layers whose outputs are compared.
        layer_wgts: one weight per entry of ``layer_ids``. Feature terms are scaled by
            ``w`` and Gram terms by ``w**2 * 5e3``.

    Raises:
        ValueError: if ``layer_ids`` and ``layer_wgts`` have different lengths.

    After each forward call, ``self.metrics`` maps ``self.metric_names`` to the
    individual loss terms (read by the LossMetrics callback given to the learner).
    """

    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        if len(layer_ids) != len(layer_wgts):
            # zip() in forward would silently drop unmatched layers/weights and
            # metric_names would no longer line up with feat_losses. Check before
            # registering any hooks so a failed construction leaves m_feat untouched.
            raise ValueError(
                f'layer_ids ({len(layer_ids)}) and layer_wgts ({len(layer_wgts)}) '
                'must have the same length')
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # Forward hooks keep each selected layer's output from the latest pass through
        # m_feat. detach=False keeps the prediction's activations in the autograd graph so
        # the feature/Gram terms back-propagate into the model being trained.
        # The hooks are removed in __del__ (no context manager is used here).
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = (['pixel']
                             + [f'feat_{i}' for i in range(len(layer_ids))]
                             + [f'gram_{i}' for i in range(len(layer_ids))])

    def make_features(self, x, clone=False):
        """Run ``x`` through ``m_feat`` and return the hooked activations (cloned if ``clone``)."""
        # Only the hooks' side effect matters; m_feat's final output is discarded.
        self.m_feat(x)
        # Cloning gives independent tensors that later passes through m_feat (which
        # refresh the hooks) cannot overwrite or alias.
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        """Return the summed loss; the individual terms are kept in ``self.metrics``."""
        # Target activations first, cloned because the next pass replaces the hooked outputs.
        out_feat = self.make_features(target, clone=True)
        # Activations of the model's prediction, computed with the same feature network.
        in_feat = self.make_features(input)
        # Pixel-level term between prediction and target.
        self.feat_losses = [base_loss(input, target)]
        # Feature terms, weighted by w.
        self.feat_losses += [base_loss(f_in, f_out) * w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        # Gram-matrix terms, weighted by w**2 * 5e3.
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out)) * w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # Remove the forward hooks from m_feat. getattr guards against an instance whose
        # __init__ raised before the hooks were created.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()

# Compare activations at three of the pre-pooling layers (blocks[2], blocks[3], blocks[4]),
# weighted 5, 15 and 2 respectively.
feat_loss = FeatureLoss(vgg_m, layer_ids=blocks[2:5], layer_wgts=[5, 15, 2])

Now we define the U-Net, imported from the local module `VUnet_def.py`.

In [None]:
# Import the U-Net definitions from the local module VUnet_def.py (presumably provides
# unet_vae_learner, used below). Python module imports take the name without ".py".
from VUnet_def import *

In [None]:
import gc  # explicit import; don't rely on it leaking in through the fastai star imports

wd = 1e-3  # weight decay
# NOTE(review): unet_vae_learner comes from VUnet_def.py; the meanings of latentdims,
# last_cross, etc. are defined there -- confirm against that module.
learn = unet_vae_learner(data, models.resnet18, wd=wd, loss_func=feat_loss,
                         callback_fns=LossMetrics, blur=True, norm_type=NormType.Weight,
                         latentdims=(10, 10), last_cross=False).to_fp16()
gc.collect();

In [None]:
# Run fastai's learning-rate range test, then plot loss vs. learning rate to choose a max LR.
learn.lr_find()
learn.recorder.plot()