# Neuronal Cell Classification

We train a classification model to predict the `cell_type` of each image in the Sartorius Cell Instance Segmentation dataset. Knowing an image's cell type up front could be useful in the instance segmentation pipeline.

```python
import fastai
from fastai.vision.all import *
```

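```python
import os
import random

import numpy as np
import torch

SEED = 3011

def fix_seeds(seed):
    # Seed Python, NumPy and PyTorch RNGs so splits, augmentation and init are repeatable.
    np.random.seed(seed)
    random.seed(seed)
    # Only affects Python processes started after this point, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # GPU kernels (cuDNN) can still be nondeterministic; fastai's
    # set_seed(seed, reproducible=True) additionally makes cuDNN deterministic.

fix_seeds(SEED)
```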
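Seeding pays off because reseeding an RNG with the same value replays the same sequence of draws. A minimal sketch of that property using only Python's `random` and NumPy; the `draw` helper is hypothetical and not part of the notebook.

```python
import random

import numpy as np

def draw(seed):
    # Hypothetical helper: reseed both RNGs the way fix_seeds does, then sample.
    random.seed(seed)
    np.random.seed(seed)
    return random.random(), np.random.rand(3).tolist()

first = draw(3011)
second = draw(3011)
# Same seed, same sequence of random numbers.
print(first == second)
```

The same holds for `torch.manual_seed`, which is what makes the random train/validation split below repeatable across runs.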

```python
TRAIN_CSV = "../input/sartorius-cell-instance-segmentation/train.csv"
TRAIN_PATH = "../input/sartorius-cell-instance-segmentation/train"
TEST_PATH = "../input/sartorius-cell-instance-segmentation/test"
```

```python
df_train = pd.read_csv(TRAIN_CSV)
```

```python
# One row per (image id, cell type) pair.
img_df = df_train[['id', 'cell_type']].drop_duplicates().reset_index(drop=True)
```

This gives a dataframe with one row per image, holding the image `id` and its `cell_type` label.

```python
img_df.head()
```
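Because `train.csv` has one row per annotated cell, every image `id` repeats many times, and `drop_duplicates` on `['id', 'cell_type']` collapses those rows to one per image. A toy sketch of that step with made-up rows (the `toy_train` frame and its values are hypothetical, not real data), which also checks the assumption that each image carries exactly one cell type:

```python
import pandas as pd

# Hypothetical toy data mimicking train.csv: one row per annotated cell,
# so each image id appears several times with the same cell_type.
toy_train = pd.DataFrame({
    "id": ["img_a", "img_a", "img_a", "img_b", "img_b", "img_c"],
    "annotation": ["rle1", "rle2", "rle3", "rle4", "rle5", "rle6"],
    "cell_type": ["shsy5y", "shsy5y", "shsy5y", "astro", "astro", "cort"],
})

# Same operation as the notebook: keep unique (id, cell_type) pairs.
toy_img_df = (toy_train[["id", "cell_type"]]
              .drop_duplicates()
              .reset_index(drop=True))

# An image with two different labels would appear twice,
# so a unique `id` column confirms one label per image.
assert toy_img_df["id"].is_unique
print(toy_img_df)
```

On the real data, the same `assert img_df['id'].is_unique` check can be run right after building `img_df`.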

We use `fastai`'s DataBlock API to build the `DataLoaders` quickly.

```python
dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),  # input: image, target: single class label
    get_x=ColReader('id', pref=f"{TRAIN_PATH}/", suff='.png'),  # -> <TRAIN_PATH>/<id>.png
    get_y=ColReader('cell_type'),
    # No splitter given: fastai's default RandomSplitter holds out 20% for validation.
)
dls = dblock.dataloaders(img_df, bs=32, num_workers=4)
dls.show_batch(figsize=(30, 22))
```
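For readers new to the DataBlock API, the plain-Python sketch below spells out what the block does for each row of `img_df`: build an image path from `id`, read the label from `cell_type`, and randomly hold out a validation set. The function names are hypothetical and this is not fastai's implementation; for instance, fastai's `RandomSplitter` shuffles with `torch.randperm` rather than Python's `random`.

```python
import random

TRAIN_PATH = "../input/sartorius-cell-instance-segmentation/train"

def get_x(row):
    # Like ColReader('id', pref=f"{TRAIN_PATH}/", suff='.png'): prefix + value + suffix.
    return f"{TRAIN_PATH}/{row['id']}.png"

def get_y(row):
    # Like ColReader('cell_type'): the raw column value becomes the label.
    return row["cell_type"]

def random_split(n_items, valid_pct=0.2, seed=None):
    # Rough analogue of fastai's default RandomSplitter: shuffle the indices,
    # the first valid_pct of them go to validation, the rest to training.
    idxs = list(range(n_items))
    random.Random(seed).shuffle(idxs)
    cut = int(valid_pct * n_items)
    return idxs[cut:], idxs[:cut]

rows = [{"id": "img_a", "cell_type": "shsy5y"}, {"id": "img_b", "cell_type": "astro"}]
print([(get_x(r), get_y(r)) for r in rows])

train_idx, valid_idx = random_split(10, seed=3011)
print(len(train_idx), len(valid_idx))
```

`ImageBlock` then opens each path as an image and `CategoryBlock` maps each label string to an integer index via the vocabulary it builds from the training labels.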

A small ImageNet-pretrained ResNet-18 (`resnet18`) is enough for our needs here, since there are only three classes (`shsy5y`, `astro`, `cort`).

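```python
# Pretrained ResNet-18 body with a new head sized to the number of classes in dls.
# (Newer fastai versions call this vision_learner; cnn_learner is the older name.)
learn = cnn_learner(dls, resnet18, metrics=accuracy)
```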

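```python
# Learner forwards unknown attributes to learn.model, so this moves the model to the GPU.
# fastai also moves the model to dls.device automatically when training starts.
learn.cuda();
```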

```python
# Learning-rate range test: plots the loss against an exponentially increasing LR.
learn.lr_find()
```
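`learn.lr_find()` runs a short learning-rate range test: it trains on a sequence of mini-batches while raising the learning rate exponentially, records the loss, and stops early once the loss diverges; a good LR sits where the loss is still falling steeply. A sketch of the exponential sweep only, assuming fastai's defaults `start_lr=1e-7`, `end_lr=10`, `num_it=100` and the interpolation position `i / num_it` (our reading of fastai's `LRFinder`; check against your installed version). `lr_sweep` is a hypothetical helper.

```python
import math

def lr_sweep(start_lr=1e-7, end_lr=10.0, num_it=100):
    # Exponential interpolation: lr = start * (end / start) ** pos, pos in [0, 1).
    # Assumption: pos = i / num_it, as we understand fastai's LRFinder.
    return [start_lr * (end_lr / start_lr) ** (i / num_it) for i in range(num_it)]

lrs = lr_sweep()
# Consecutive LRs differ by a constant factor, so the sweep is a straight line on a log axis.
ratio = lrs[1] / lrs[0]
print(f"first {lrs[0]:.1e}, last {lrs[-1]:.2f}, step factor {ratio:.4f}")
```

The constant step factor is why `lr_find` plots the loss on a logarithmic LR axis.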

```python
# One epoch with the pretrained body frozen, then 5 epochs training the whole network.
learn.fine_tune(5, 1e-3)
```
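`fine_tune(5, 1e-3)` is a two-phase schedule. The sketch below reproduces only its learning-rate arithmetic, following our reading of fastai's `Learner.fine_tune` source (defaults `freeze_epochs=1`, `lr_mult=100`); `fine_tune_schedule` is a hypothetical helper, not a fastai function, and the values should be checked against your installed fastai version.

```python
def fine_tune_schedule(epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100):
    """Hypothetical helper: describe the two phases of fastai's fine_tune."""
    # Phase 1: body frozen, only the new head trains (one-cycle, peak LR base_lr).
    frozen = {"epochs": freeze_epochs, "max_lr": base_lr}
    # Phase 2: base_lr is halved, the whole network is unfrozen, and layer groups
    # get discriminative LRs spread from base_lr / lr_mult (earliest) to base_lr (head).
    base_lr /= 2
    unfrozen = {"epochs": epochs, "max_lr": (base_lr / lr_mult, base_lr)}
    return frozen, unfrozen

frozen, unfrozen = fine_tune_schedule(5, 1e-3)
print(frozen)    # 1 epoch with the body frozen
print(unfrozen)  # 5 epochs, LRs from about 5e-06 up to 5e-04
```

So the call above trains for 6 epochs in total, and the pretrained early layers move far more slowly than the new head.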

```python
# Save the trained model and its data-processing pipeline (not the data) under learn.path.
learn.export("cell-classification-learner.pkl")
```

```python
# Reload the exported Learner, e.g. in a separate inference notebook.
learn2 = load_learner("cell-classification-learner.pkl")
```

```python
# load_learner already loads onto the CPU by default (cpu=True); this makes it explicit.
learn2.cpu();
```

## References

`fastai` was used to create the dataloaders and to build and train the model; it offers a quick and flexible way to do both.  
More about `fastai` vision can be found in the [fastai vision tutorial](https://docs.fast.ai/tutorial.vision.html).