# Lecture 60: Activity recognition using CNN-LSTM
## 60b: Feature extraction using CNN

In [None]:
import os
import pickle
import torch
import numpy as np
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms,datasets

# Report the installed PyTorch version (this notebook was updated for PyTorch 1.0.0).
installed_version = torch.__version__
print(installed_version)

In [None]:
# Pick the compute device: "cuda" when PyTorch reports an available GPU,
# otherwise "cpu". `use_gpu` is also consulted later when moving outputs back.
use_gpu = torch.cuda.is_available()
# use_gpu = False # Uncomment in case of GPU memory error
device = "cuda" if use_gpu else "cpu"
print('GPU is available!' if use_gpu else 'GPU is not available!')

In [None]:
# Load the pickled train/test lists. Items are later split on '_', so they
# are expected to be strings.
# NOTE: pickle can execute arbitrary code on load — only use trusted files.
def _load_pickle(path):
    """Return the object unpickled from the file at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)

trainList = _load_pickle('trainList_5class.pckl')
testList = _load_pickle('testList_5class.pckl')

In [None]:
# Unique class names, in first-seen order, taken from the second
# '_'-separated field of each training item (dict keys preserve order).
classes = list(dict.fromkeys(entry.split('_')[1] for entry in trainList))
print(classes)

### Initialize network and load trained weights

In [None]:
# Rebuild ResNet-18 with a 5-way classifier head so the saved state dict
# (which contains that head) loads without key/shape mismatches.
net = models.resnet18()
net.fc = nn.Linear(512, 5)
# Loading saved states; map_location lets weights saved on a GPU load on a
# CPU-only machine (and places them on the selected device otherwise).
net.load_state_dict(torch.load('resnet18Pre_fcOnly5class_ucf101_10adam_1e-4_b128.pt',
                               map_location=device))
# Removing fully connected layer for feature extraction: the remaining
# children form the feature extractor.
model = nn.Sequential(*list(net.children())[:-1]).to(device)
# Inference mode: BatchNorm layers use their stored running statistics
# instead of normalising each single-image batch by its own statistics
# (which would also overwrite the running statistics).
model.eval()

### Feature extraction

In [None]:
# PIL image to tensor transformation
data_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),            
        transforms.ToTensor()
    ])

In [None]:
# Extract per-frame CNN features for every training video and save one tensor
# per video (frame count x 512) to ucf101_resnet18Feat/train/<class>/<item>.pt.
framePath = 'frames/'
model.eval()  # use stored BatchNorm statistics; harmless if already set
with torch.no_grad():  # inference only: no autograd graph is needed
    for item in trainList:
        # Class name is the second '_'-separated field of the item name.
        cName = item.split('_')[1]
        srcPath = os.path.join(framePath, cName, item)
        fNames = os.listdir(srcPath)
        if not fNames:
            print(srcPath + ' has no frames, skipping!')
            continue
        # filename template: first four '_'-separated fields are shared by all
        # frames; the frame index (1-based) is appended to rebuild each name
        # so frames are read in temporal order regardless of listdir order.
        fTemplate = fNames[0].split('_')
        prefix = '_'.join(fTemplate[:4])
        feats = []
        for fNum in range(1, len(fNames) + 1):
            fileName = prefix + '_' + str(fNum) + '.jpg'
            filePath = os.path.join(srcPath, fileName)
            if not os.path.exists(filePath):
                print(fileName + ' missing!')
                continue
            # Close the file promptly; force 3 channels for the CNN input.
            with Image.open(filePath) as img:
                imgTensor = data_transforms(img.convert('RGB')).unsqueeze(0)
            feat = model(imgTensor.to(device))
            # Flatten to (1, features) and keep on CPU for stacking/saving.
            feats.append(feat.view(feat.size(0), -1).cpu())
        if not feats:
            # Saving now would write a stale or undefined tensor; skip instead.
            print(item + ': no frames could be loaded, skipping!')
            continue
        # out dimension -> frame count x 512 (single concat, not per frame)
        out = torch.cat(feats, 0)
        featSavePath = os.path.join('ucf101_resnet18Feat/train', cName)  # Directory for saving features
        os.makedirs(featSavePath, exist_ok=True)
        torch.save(out, os.path.join(featSavePath, item + '.pt'))
    

In [None]:
# Extract per-frame CNN features for every test video and save one tensor
# per video (frame count x 512) to ucf101_resnet18Feat/test/<class>/<item>.pt.
framePath = 'frames/'
model.eval()  # use stored BatchNorm statistics; harmless if already set
with torch.no_grad():  # inference only: no autograd graph is needed
    for item in testList:
        # Class name is the second '_'-separated field of the item name.
        cName = item.split('_')[1]
        srcPath = os.path.join(framePath, cName, item)
        fNames = os.listdir(srcPath)
        if not fNames:
            print(srcPath + ' has no frames, skipping!')
            continue
        # filename template: first four '_'-separated fields are shared by all
        # frames; the 1-based frame index is appended to rebuild each name so
        # frames are read in temporal order regardless of listdir order.
        fTemplate = fNames[0].split('_')
        prefix = '_'.join(fTemplate[:4])
        feats = []
        for fNum in range(1, len(fNames) + 1):
            fileName = prefix + '_' + str(fNum) + '.jpg'
            filePath = os.path.join(srcPath, fileName)
            if not os.path.exists(filePath):
                print(fileName + ' missing!')
                continue
            # Close the file promptly; force 3 channels for the CNN input.
            with Image.open(filePath) as img:
                imgTensor = data_transforms(img.convert('RGB')).unsqueeze(0)
            feat = model(imgTensor.to(device))
            # Flatten to (1, features) and keep on CPU for stacking/saving.
            feats.append(feat.view(feat.size(0), -1).cpu())
        if not feats:
            # Saving now would write a stale or undefined tensor; skip instead.
            print(item + ': no frames could be loaded, skipping!')
            continue
        # out dimension -> frame count x 512 (single concat, not per frame)
        out = torch.cat(feats, 0)
        featSavePath = os.path.join('ucf101_resnet18Feat/test', cName)
        os.makedirs(featSavePath, exist_ok=True)
        torch.save(out, os.path.join(featSavePath, item + '.pt'))