In [None]:
import copy
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import PIL
import PIL.Image
import PIL.ImageOps
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset, DataLoader

In [None]:
from tqdm.auto import tqdm

In [None]:
# Root folders scanned by AllTangramImages; each one contains sub-folders of image files.
f_path_list = [f"data/labeldata{folder_id}/" for folder_id in (2, 3)]

In [None]:
IMAGE_SIZE = 224  # images are resized to IMAGE_SIZE x IMAGE_SIZE by the transform pipeline

In [None]:
class AllTangramImages(Dataset):
    """Grayscale, colour-inverted tangram images paired with a normalised score.

    Expected layout: ``<root>/<sub_folder>/<image_file>`` for every root in
    ``folder_path_list``. ``__getitem__`` returns ``(image, score / 8.0)``, where
    the raw score is parsed from the file name (see ``_score_from_path``).
    """

    def __init__(self, folder_path_list, transforms=None):
        # folder_path_list: root directories, each holding sub-folders of images.
        # transforms: optional callable applied to the PIL image in __getitem__.
        self.data_root_folders = folder_path_list
        self.image_list = self.load_image_list()
        self.transforms = transforms

    def load_image_list(self):
        """Return the paths of all files exactly two levels below each root folder.

        Directory listings are sorted so the index -> file mapping (and any
        random split taken over it) does not depend on filesystem order.
        Stray non-directories at the first level (e.g. ``.DS_Store``) and
        non-files at the second level are skipped instead of crashing.
        """
        image_list = []
        for folder_path in self.data_root_folders:
            for time_folder in sorted(os.listdir(folder_path)):
                image_folder = os.path.join(folder_path, time_folder)
                if not os.path.isdir(image_folder):
                    continue
                for file_name in sorted(os.listdir(image_folder)):
                    file_path = os.path.join(image_folder, file_name)
                    if os.path.isfile(file_path):
                        image_list.append(file_path)
        return image_list

    def __len__(self):
        return len(self.image_list)

    @staticmethod
    def _score_from_path(path):
        """Parse the raw (un-normalised) score encoded in ``path``.

        Takes the token between the last separator (``/``, ``\\`` or ``.``)
        and the final ``.`` -- i.e. the file-name stem when it contains no
        dots: ``.../t1/5.png`` -> 5.0.
        NOTE(review): a dotted stem such as ``3.5.png`` yields 5.0, not 3.5;
        confirm this matches the file naming scheme.
        """
        return float(re.split(r"(\\|\.|/)", path)[-3])

    def __getitem__(self, index):
        """Return ``(image, score / 8.0)`` for the file at ``index``."""
        single_image_path = self.image_list[index]
        # Context manager guarantees the underlying file handle is released.
        with PIL.Image.open(single_image_path) as raw_img:
            img = raw_img.convert("L")
        img = PIL.ImageOps.invert(img)
        if self.transforms is not None:
            img = self.transforms(img)

        score = self._score_from_path(single_image_path)
        return img, score / 8.0
        

In [None]:
# Training-time pipeline: resize, random geometric augmentation, then a float tensor in [0, 1].
transforms = torchvision.transforms.Compose([
    # InterpolationMode is the documented form; PIL integer constants are deprecated here.
    torchvision.transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=torchvision.transforms.InterpolationMode.NEAREST),
    torchvision.transforms.RandomAffine(degrees=90, translate=(0.2, 0.2), scale=(0.6, 1)),
    # NOTE(review): AllTangramImages converts images to single-channel "L" mode, on which
    # hue/saturation jitter has no effect -- consider brightness/contrast jitter or removing it.
    torchvision.transforms.ColorJitter(hue=.05, saturation=.05),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomVerticalFlip(),
    torchvision.transforms.ToTensor(),
])

In [None]:
# Full (augmented) dataset. Fail fast on a misconfigured path instead of erroring later in the split/plot cells.
ati = AllTangramImages(f_path_list, transforms)
assert len(ati) > 0, f"no images found under {f_path_list}"

In [None]:
# Number of image files discovered across all root folders (shown as the cell output).
len(ati)

In [None]:
# Visual sanity check: channel 0 of one augmented sample, titled with its normalised score.
# Same index (-300) as before, clamped so datasets with < 300 images do not raise IndexError.
sample_img, sample_score = ati[-min(300, len(ati))]
fig, ax = plt.subplots()
ax.imshow(sample_img[0], cmap="gray")
ax.set_title(f"sample score / 8 = {sample_score:.3f}")
ax.axis("off");

In [None]:
from efficientnet_pytorch import EfficientNet

In [None]:
class EfficientNeuralNetwork(nn.Module):
    """EfficientNet-b0 feature extractor followed by a small sigmoid regression head.

    Maps a single-channel image batch to one score per image in (0, 1).
    """

    # Channel count of EfficientNet-b0's final feature map = input size of the head.
    FEATURE_DIM = 1280

    def __init__(self):
        super(EfficientNeuralNetwork, self).__init__()
        # Architecture with 1 input channel. num_classes=0 because the library's own
        # classifier is never used: forward() calls extract_features() and applies our head.
        self.efficient = EfficientNet.from_name('efficientnet-b0', in_channels=1, num_classes=0)

        self.scores_head = nn.Sequential(
            nn.Linear(self.FEATURE_DIM, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, inputs, params=None):
        """Return scores of shape (batch, 1).

        ``params`` is unused; it is kept so existing call sites keep working.
        """
        # Assumed (B, C, H, W) final feature map from the backbone -- TODO confirm for other variants.
        features = self.efficient.extract_features(inputs)
        # Global average pooling per channel: (B, C, H, W) -> (B, C, H*W) -> (B, C).
        # The channel axis must be preserved; flattening it into the spatial axes and
        # reshaping to (B, B, -1) produced (B, B) features that do not fit the head.
        features = features.flatten(start_dim=2).mean(dim=2)
        return self.scores_head(features)

In [None]:
# EfficientNet-b0 backbone (1-channel input) + sigmoid regression head. `from_name` builds the
# architecture only -- no pretrained weights are loaded in this cell.
model = EfficientNeuralNetwork()

In [None]:
# Put the model on the GPU when one is present; on a CPU-only machine it stays where it is.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# 90 / 10 train/test sizes; whatever the floor drops from the train share goes to the test share.
TRAIN_FRACTION = 0.9
n_images = len(ati)
train_dataset_len = int(TRAIN_FRACTION * n_images)
test_dataset_len = n_images - train_dataset_len

In [None]:
# Seeded generator so the train/test partition is identical across kernel restarts.
SPLIT_SEED = 42
train_dataset, test_dataset = torch.utils.data.random_split(
    ati,
    [train_dataset_len, test_dataset_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# Training batches: the augmented split, reshuffled every epoch.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)

# Evaluation batches: the same held-out indices, read through a shallow copy of `ati` whose
# pipeline only resizes and converts to a tensor. `test_dataset` is a view of `ati`, so reading
# it directly would apply the random affine/flip augmentation at test time. The shallow copy
# shares `image_list`, so the split indices still point at the same files.
eval_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=torchvision.transforms.InterpolationMode.NEAREST),
    torchvision.transforms.ToTensor(),
])
eval_view = copy.copy(ati)
eval_view.transforms = eval_transforms
test_eval_dataset = torch.utils.data.Subset(eval_view, test_dataset.indices)
# Evaluation needs no shuffling; a fixed order keeps the test loss deterministic.
test_dataloader = DataLoader(test_eval_dataset, batch_size=16, shuffle=False, num_workers=0)

In [None]:
# Adam over every model parameter with a 1e-3 learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Number of full passes over train_dataloader in the training loop below.
epochs = 20

In [None]:
# Train for `epochs` epochs. Each epoch optimises on the training split, then scores the held-out
# split; the printed numbers are mean squared error per sample (targets are score / 8).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for epoch in range(epochs):
    model.train()
    train_sq_err, train_count = 0.0, 0
    for batch_inputs, batch_scores in tqdm(train_dataloader, desc=f"epoch {epoch} train"):
        batch_inputs = batch_inputs.to(device)
        # Default collate turns the dataset's Python floats into float64; match the float32 predictions.
        batch_scores = batch_scores.to(device, dtype=torch.float32)

        pred_scores = model(batch_inputs).view(-1)

        # Sum-of-squares objective as before; only the reported numbers are normalised per sample.
        loss = torch.sum((pred_scores - batch_scores) ** 2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_sq_err += loss.item()
        train_count += batch_scores.numel()

    model.eval()
    test_sq_err, test_count = 0.0, 0
    # Evaluation needs no gradients; skipping autograd bookkeeping saves memory and time.
    with torch.no_grad():
        for batch_inputs, batch_scores in tqdm(test_dataloader, desc=f"epoch {epoch} test"):
            batch_inputs = batch_inputs.to(device)
            batch_scores = batch_scores.to(device, dtype=torch.float32)
            pred_scores = model(batch_inputs).view(-1)
            test_sq_err += torch.sum((pred_scores - batch_scores) ** 2).item()
            test_count += batch_scores.numel()

    # Per-sample averages do not depend on batch size or on a smaller final batch
    # (a mean of per-batch sums does); max(..., 1) guards against an empty split.
    print("epoch: {} mse train {:.4f} mse test {:.4f}".format(
        epoch, train_sq_err / max(train_count, 1), test_sq_err / max(test_count, 1)))

In [None]:
# Persist only the learned weights (the state_dict), not the module object itself.
CHECKPOINT_PATH = "11_4_eff_pre_train.pth"
torch.save(model.state_dict(), CHECKPOINT_PATH)