In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# from fastai.vision import *
import torch
import json
import torch.nn.functional as F

from pathlib import Path
from torch import Tensor
from mrnet_orig import MR3DImageList, MRNet
from learn import *
from progress import *

%matplotlib inline

In [3]:
# Show the directory layout (directories only) of the parent folder, where the data lives.
! tree -d ..

[01;34m..[00m
├── [01;34mdata[00m
│   ├── [01;34maxial[00m
│   │   ├── [01;34mtrain[00m
│   │   └── [01;34mvalid[00m
│   ├── [01;34mcoronal[00m
│   │   ├── [01;34mtrain[00m
│   │   └── [01;34mvalid[00m
│   └── [01;34msagittal[00m
│       ├── [01;34mmodels[00m
│       ├── [01;34mtrain[00m
│       └── [01;34mvalid[00m
└── [01;34mmrnet-fastai[00m
    ├── [01;34mexp[00m
    └── [01;34m__pycache__[00m

14 directories


In [4]:
# List the files in the current project directory.
! ls

callbacks.py	 loss_weights.pt	     progress.py
CONTRIBUTORS.md  MRNet_EDA.ipynb	     __pycache__
df_abnl.pkl	 MRNet_EDA_ns.ipynb	     README.md
exp		 MRNet_fastai_example.ipynb  slice_stats.json
exports.ipynb	 MRNet_fastai_v2.ipynb	     train_cases.pkl
learn.py	 mrnet_orig.py
LICENSE		 notebook2script.py


In [5]:
# List the data directory: one folder per MRI plane plus the train/valid label CSVs.
! ls ../data

axial	  train-abnormal.csv  valid-abnormal.csv
coronal   train-acl.csv       valid-acl.csv
sagittal  train-meniscus.csv  valid-meniscus.csv


In [6]:
# Root of the dataset, plus one sub-directory per MRI plane.
data_path = Path('../data')
sag_path, cor_path, ax_path = (data_path / plane for plane in ('sagittal', 'coronal', 'axial'))

## Substantial class imbalance for the normal/abnormal task

About 81% of the training cases are abnormal (see the fraction computed below), so we derive class weights for a weighted binary cross-entropy loss.

In [7]:
# The label CSV has no header row: name the two columns explicitly and keep case ids as strings.
train_abnl = pd.read_csv(
    data_path / 'train-abnormal.csv',
    header=None,
    names=['Case', 'Abnormal'],
    dtype={'Case': str, 'Abnormal': np.int64},
)
print(train_abnl.shape)
train_abnl.head()

(1130, 2)


Unnamed: 0,Case,Abnormal
0,0,1
1,1,1
2,2,1
3,3,1
4,4,1


In [8]:
# w = fraction of training cases labelled abnormal (mean of a 0/1 column).
# The tensor [w, 1 - w] is saved so later cells can reload it without recomputing.
w = train_abnl['Abnormal'].mean()
print(w)
weights = Tensor([w, 1 - w])
print(weights)
torch.save(weights, 'loss_weights.pt')

0.8079646017699115
tensor([0.8080, 0.1920])


In [9]:
# Reload the class-weight tensor [w, 1 - w] written by the previous cell
# (w = fraction of abnormal training cases), so later cells can start from the saved file.
weights = torch.load('loss_weights.pt')

## Load previously created files

- `df_abnl` -> master `df` for use with the Data Block API; also contains the number of slices in each series for every plane
- `slice_stats` -> `dict` stored as `json` with the mean and max number of slices per series, for each plane (coronal, sagittal, axial)

In [10]:
# Master dataframe: one row per case with its label (`Abnormal`), train/valid flag (`is_valid`)
# and slice counts per plane (see the displayed columns).
# NOTE: pd.read_pickle can execute arbitrary code - only load pickles you created yourself.
df_abnl = pd.read_pickle('df_abnl.pkl')
df_abnl.head()

Unnamed: 0,Case,Abnormal,is_valid,coronal_slices,sagittal_slices,axial_slices
0,train/0000,1,0,25,27,25
1,train/0001,1,0,22,23,28
2,train/0002,1,0,24,24,24
3,train/0003,1,0,22,21,25
4,train/0004,1,0,30,30,31


In [11]:
# Per-plane slice-count summary (mean and max over all series), created earlier and cached as JSON.
stats = json.loads(Path('slice_stats.json').read_text())

stats

{'coronal': {'mean': 29.6416, 'max': 57},
 'sagittal': {'mean': 30.3776, 'max': 51},
 'axial': {'mean': 34.2032, 'max': 61}}

In [12]:
# Largest number of slices in any sagittal series (from slice_stats.json).
max_slc = stats['sagittal']['max']
print(max_slc)

51


## MRNet implementation

Adapted from the original MRNet [paper](https://journals.plos.org/plosmedicine/article?id=10.1371/journal.pmed.1002699) so that it (mostly) works with `fastai`.

In [13]:
# Build the sagittal image list from the master dataframe: each `Case` (e.g. 'train/0000')
# resolves to '<sag_path>/<Case>.npy', as shown by il.items[0] below.
il = MR3DImageList.from_df(df_abnl, sag_path, suffix='.npy')

In [14]:
# Each item is the path of one case's .npy volume.
il.items[0]

'../data/sagittal/train/0000.npy'

In [15]:
# Summary of the list. The first dimension of each item (51) equals the sagittal max slice count;
# presumably the layout is (slices, channels, H, W) - confirm in mrnet_orig.MR3DImageList.
il

MR3DImageList (1250 items)
Image (51, 3, 256, 256),Image (51, 3, 256, 256),Image (51, 3, 256, 256),Image (51, 3, 256, 256),Image (51, 3, 256, 256)
Path: ../data/sagittal

In [16]:
# Split train/valid on df_abnl's `is_valid` flag (truthy -> validation set).
# Refer to the column by name rather than by position (it is column 2), so the split stays
# correct if df_abnl's columns are ever reordered.
sd = il.split_from_df(col='is_valid')
sd

ItemLists;

Train: MR3DImageList (1130 items)
Image (51, 3, 256, 256),Image (51, 3, 256, 256),Image (51, 3, 256, 256),Image (51, 3, 256, 256),Image (51, 3, 256, 256)
Path: ../data/sagittal;

Valid: MR3DImageList (120 items)
Image (51, 3, 256, 256),Image (51, 3, 256, 256),Image (51, 3, 256, 256),Image (51, 3, 256, 256),Image (51, 3, 256, 256)
Path: ../data/sagittal;

Test: None

In [17]:
# Label each case from the 0/1 `Abnormal` column, referenced by name rather than by
# position (column 1), so that reordering df_abnl's columns cannot silently change the target.
ll = sd.label_from_df(cols='Abnormal')
ll

LabelLists;

Train: LabelList (1130 items)
x: MR3DImageList
Image (51, 3, 256, 256),Image (51, 3, 256, 256),Image (51, 3, 256, 256),Image (51, 3, 256, 256),Image (51, 3, 256, 256)
y: CategoryList
1,1,1,1,1
Path: ../data/sagittal;

Valid: LabelList (120 items)
x: MR3DImageList
Image (51, 3, 256, 256),Image (51, 3, 256, 256),Image (51, 3, 256, 256),Image (51, 3, 256, 256),Image (51, 3, 256, 256)
y: CategoryList
0,0,0,0,0
Path: ../data/sagittal;

Test: None

In [18]:
# tfms = get_transforms()

In [19]:
# bs must stay 1: MRNetCallback.begin_batch squeezes dim 0 of xb, which only removes the batch
# axis when that axis has size 1.
bs = 1
data = ll.databunch(bs=bs)

In [20]:
# Base callback factories (from learn.py / progress.py): averaged stats with `accuracy` as the
# metric, plus CudaCallback (presumably moves model and batches to the GPU - confirm).
# NOTE(review): confirm `accuracy` works with MRNet's single-logit prediction (an argmax-based
# multi-class accuracy would not).
cbfs = [partial(AvgStatsCallback,accuracy),
        CudaCallback]

In [21]:
# Weighted BCE-with-logits for abnormal (1) vs normal (0).
# `weights` is [w, 1 - w], where w = fraction of abnormal training cases.
# `pos_weight` scales only the positive-class loss term.
# To balance the classes it must be n_neg / n_pos = (1 - w) / w = weights[1] / weights[0]
# (about 0.24 here); positives and negatives then get equal total weight.
# That is the same relative weighting as applying [w, 1 - w] per sample.
# Using weights[1] alone (about 0.19) would over-correct.
mrnet_loss = partial(F.binary_cross_entropy_with_logits, pos_weight=weights[1] / weights[0])

class MRNetCallback(Callback):
    """Reshape fastai-style batches for MRNet and make the targets usable by the BCE loss.

    begin_batch drops the leading (batch) axis of the inputs. torch squeeze on dim 0 is a no-op
    unless that axis has size 1, hence bs = 1.
    after_pred drops the leading axis of the predictions and casts the targets to float, because
    F.binary_cross_entropy_with_logits needs floating-point targets.
    NOTE(review): the loss also needs pred and yb to have the same shape - confirm against
    MRNet's output shape.
    """

    def begin_batch(self):
        batch = self.xb
        self.run.xb = batch.squeeze(0)

    def after_pred(self):
        logits, targets = self.pred, self.yb
        self.run.pred = logits.squeeze(0)
        self.run.yb = targets.float()


In [22]:
def mrnet_learner(data, lr, loss_func=mrnet_loss, cb_funcs=None, opt_func=optim.Adam):
    """Return a Learner that trains a newly constructed MRNet.

    data: the DataBunch to train on. lr: learning rate handed to the Learner.
    loss_func: defaults to mrnet_loss as it existed when this cell was run.
    cb_funcs: callback factories. opt_func: optimiser constructor, Adam by default.
    """
    return Learner(MRNet(), data, loss_func, lr=lr, cb_funcs=cb_funcs, opt_func=opt_func)


In [23]:
# Two-phase cosine lr schedule: rise over the first 30% of training, then anneal over the rest.
# ParamScheduler('lr', sched) presumably writes sched(pos) directly into each param group's lr,
# as in the fastai course (confirm in learn.py). That overrides the lr given to mrnet_learner,
# so the schedule must be on Adam's scale.
# The course example's absolute values (0.3 -> 0.6 -> 0.2) are far too large for Adam.
# Keep the same shape (half-peak -> peak -> third-of-peak), peaking at the learner's lr of 1e-5.
max_lr = 1e-5
sched = combine_scheds([0.3, 0.7], [sched_cos(max_lr / 2, max_lr), sched_cos(max_lr, max_lr / 3)])

In [24]:
# Build the complete callback list in one assignment instead of `cbfs += [...]`.
# Re-running this cell then leaves the list unchanged rather than appending a second copy of each callback.
cbfs = [partial(AvgStatsCallback, accuracy),
        CudaCallback,
        Recorder,
        partial(ParamScheduler, 'lr', sched),
        ProgressCallback,
        MRNetCallback]

In [25]:
# Create the learner with base lr 1e-5 and the callbacks above.
# NOTE(review): ParamScheduler('lr', sched) in cbfs presumably overrides this lr every batch - confirm.
learn = mrnet_learner(data, 1e-5, cb_funcs=cbfs)

In [26]:
# learn.fit(1)