In [14]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

import torch

from sklearn.metrics import accuracy_score

In [24]:
TRAIN_SIZE = 1000
TEST_SIZE = 200

BATCH_SIZE = 32
EPOCHS = 100

def init_weights(m):
    if isinstance(m, torch.nn.Linear) or isinstance(m, torch.nn.Conv2d):
        torch.nn.init.xavier_normal_(m.weight)
        m.bias.data.fill_(0.01)

In [16]:
def create_sample(width, height, movie_length, direction):
  frames = list()

  last_point_in_frame = [np.random.randint(height), np.random.randint(width)]
  for time_index in range(movie_length):
    frame = np.ones((height, width))
    frame[last_point_in_frame[0], last_point_in_frame[1]] = 0

    new_width = last_point_in_frame[1]+direction
    new_height = last_point_in_frame[0]+(np.random.randint(0, 3) - 1)

    if new_width>=width:
      new_width = 0
    if new_width<0:
      new_width = width - 1

    if new_height>=height:
      new_height = 0

    if new_height<0:
      new_height = height - 1

    last_point_in_frame[1] = new_width
    last_point_in_frame[0] = new_height


    frames.append(frame)

  frames = np.array(frames)

  return frames


In [17]:
train_raw_data = list()
train_labels = list()

test_raw_data = list()
test_labels = list()

for _ in range(TRAIN_SIZE):
  if np.random.randn()>0.5:
    direction = 1
    label = 1
  else:
    direction = -1
    label = 0


  frames = create_sample(width=15, height=8, movie_length=10, direction=direction)

  train_raw_data.append(frames)
  train_labels.append(label)

train_raw_data = np.array(train_raw_data)
train_labels = np.array(train_labels)

for _ in range(TEST_SIZE):
  if np.random.randn()>0.5:
    direction = 1
    label = 1
  else:
    direction = -1
    label = 0


  frames = create_sample(width=15, height=8, movie_length=10, direction=direction)

  test_raw_data.append(frames)
  test_labels.append(label)


test_raw_data = np.array(test_raw_data)
test_labels = np.array(test_labels)


In [18]:
class MovieDataset(torch.utils.data.Dataset):
  def __init__(self, dataset, labels):
    super(MovieDataset, self).__init__()

    self.dataset = np.expand_dims(dataset, axis=1).astype(np.float32)
    self.labels = np.expand_dims(labels, axis=1).astype(np.float32)


  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, index):
    data = self.dataset[index]
    label = self.labels[index]

    return data, label

In [19]:
class MovieClassifier(torch.nn.Module):
  def __init__(self):
    super(MovieClassifier, self).__init__()

    self.lstm_hidden_size = 32

    self.conv_layers = torch.nn.ModuleList()

    input_channel = 1
    for channel_index in range(4):
      if channel_index<1:
        self.conv_layers.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=2*input_channel, kernel_size=(2, 2), stride=(2, 2)))
      else:
        self.conv_layers.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=2*input_channel, kernel_size=(2, 2), stride=(1, 1)))
      input_channel*=2

    self.conv_layers.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=4, kernel_size=(1, 1)))
    ######################################
    self.lstm_layer_1 = torch.nn.LSTM(16, self.lstm_hidden_size, 1, batch_first=True)
    ######################################
    self.linear_classifier_l1 = torch.nn.Linear(self.lstm_hidden_size, 16)
    self.linear_classifier_l2 = torch.nn.Linear(16, 8)
    self.linear_classifier_l3 = torch.nn.Linear(8, 1)




  def forward(self, x):
    h0 = torch.zeros(1, x.size(0), self.lstm_hidden_size)
    c0 = torch.zeros(1, x.size(0), self.lstm_hidden_size)

    for frame_index in range(x.shape[2]):
      frame = x[:, 0, frame_index, :, :].unsqueeze(1)
      for layer_index, layer in enumerate(self.conv_layers):
        frame = layer(frame)

        if layer_index >= 2:
          frame = torch.nn.functional.relu(frame)

      frame = frame.view(x.size(0), 1, -1)
      lstm_out, (h0, c0) = self.lstm_layer_1(frame, (h0, c0))

    lstm_out = lstm_out.view(x.size(0), -1)

    classifier_output = self.linear_classifier_l1(lstm_out)
    classifier_output = self.linear_classifier_l2(classifier_output)
    classifier_output = self.linear_classifier_l3(classifier_output)
    classifier_output = torch.nn.functional.sigmoid(classifier_output)

    return classifier_output



In [58]:
train_torch_dataset = MovieDataset(dataset=train_raw_data, labels=train_labels)
train_dataloader = torch.utils.data.DataLoader(dataset=train_torch_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_torch_dataset = MovieDataset(dataset=test_raw_data, labels=test_labels)
test_dataloader = torch.utils.data.DataLoader(dataset=test_torch_dataset, batch_size=BATCH_SIZE)

model = MovieClassifier()

model.apply(init_weights)

loss_function = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [59]:
for epoch in range(EPOCHS):
  print(f"Processing epoch {epoch+1}")
  for data, labels in train_dataloader:
    model.zero_grad()

    predictions = model(data)

    loss = loss_function(predictions, labels)
    loss.backward()
    optimizer.step()

  print(f"Loss: {loss.item()}")


Processing epoch 1
Loss: 0.8248739242553711
Processing epoch 2
Loss: 0.5661195516586304
Processing epoch 3
Loss: 0.5659032464027405
Processing epoch 4
Loss: 0.5706061124801636
Processing epoch 5
Loss: 0.5639681816101074
Processing epoch 6
Loss: 0.570500910282135
Processing epoch 7
Loss: 0.5644223093986511
Processing epoch 8
Loss: 0.440577894449234
Processing epoch 9
Loss: 0.46991297602653503
Processing epoch 10
Loss: 0.6899201273918152
Processing epoch 11
Loss: 0.5675331354141235
Processing epoch 12
Loss: 0.44704771041870117
Processing epoch 13
Loss: 0.564027726650238
Processing epoch 14
Loss: 0.6651592254638672
Processing epoch 15
Loss: 0.8074864745140076
Processing epoch 16
Loss: 0.5664162635803223
Processing epoch 17
Loss: 0.5656944513320923
Processing epoch 18
Loss: 0.6735174655914307
Processing epoch 19
Loss: 0.45761409401893616
Processing epoch 20
Loss: 0.45167452096939087
Processing epoch 21
Loss: 0.6753695607185364
Processing epoch 22
Loss: 0.791152834892273
Processing epoch 23

In [60]:
predictions_list = None
true_labels_list = None
with torch.no_grad():
  for data, labels in test_dataloader:
    predictions = torch.round(model(data))

    if predictions_list is None:
      predictions_list = predictions.detach().numpy()
      true_labels_list = labels.detach().numpy()
    else:
      predictions_list = np.append(predictions_list, predictions.detach().numpy(), axis=0)
      true_labels_list = np.append(true_labels_list, labels.detach().numpy(), axis=0)



print(f"Accuracy is : {accuracy_score(predictions_list, true_labels_list):.2f}")

Accuracy is : 1.00
