In [None]:
import torch
import torch.nn.functional as F
import torchvision
from torch.autograd import Variable
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from torch.utils.data import Dataset, DataLoader
import os
from sklearn.utils import shuffle
from tqdm import tqdm
from typing import Tuple, List, Type, Dict, Any

In [None]:
import cv2


In [None]:
from queue import Empty, Queue
from threading import Thread
import threading
# Libraries for worker threads and the queues that connect them

In [None]:
# augmentation library
from imgaug.augmentables import Keypoint, KeypointsOnImage
import imgaug.augmenters as iaa 
#import accimage

In [None]:
# Mount Google Drive at /content/drive/ (requires the google.colab package, i.e. a Colab runtime).
from google.colab import drive
drive.mount('/content/drive/')

In [None]:
# Load the pickled dataset index from the mounted Drive.
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only open index files that come from a trusted source.
with open('/content/drive/MyDrive/geo_kaggle/index.pkl', 'rb') as f:
    data_index = pickle.load(f)

In [None]:
# Pick the compute device once: the first CUDA GPU when one is available,
# otherwise the CPU. The chosen device is announced on stdout.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print('Using GPU', f'({torch.cuda.get_device_name(), torch.cuda.get_device_properties(device)})')
else:
    print('Using CPU')


In [None]:
class thread_killer:
    """Callable stop flag shared between a controller and its worker threads.

    A worker that runs an endless loop polls the flag by calling the object
    (``flag()``); once the controller has stored a true value through
    ``set_tokill`` the worker is expected to leave its loop, which lets the
    thread function return and the thread terminate.
    """

    def __init__(self):
        # A fresh flag starts in the "keep running" state.
        self.to_kill = False

    def set_tokill(self, tokill):
        """Store the new flag value (truthy means: please stop)."""
        self.to_kill = tokill

    def __call__(self):
        """Return the current flag value."""
        return self.to_kill

In [None]:
def threaded_batches_feeder(tokill, batches_queue, dataset_generator, put_timeout=0.1):
    """
    Threaded worker that moves batches from a dataset into a bounded queue.

    The dataset is iterated over and over until the stop flag is raised.
    Unlike a plain blocking ``put``, the hand-over to the queue is retried in
    short time slices, so a worker that faces a full queue still notices the
    stop flag and returns instead of blocking forever.

    tokill (thread_killer): callable returning True once the thread should terminate
    batches_queue (Queue): a limited size thread-safe Queue instance for train/validation data batches
    dataset_generator (Dataset): training/validation data generator (any iterable of batches)
    put_timeout (float): seconds to wait for free queue space before re-checking ``tokill``
    """
    # Local import keeps the function self-contained; the notebook's import
    # cell only brings in Empty and Queue.
    from queue import Full

    while not tokill():
        for sample_batch in dataset_generator:
            # Keep offering the batch until it is accepted or a stop is requested.
            while True:
                if tokill():
                    return
                try:
                    batches_queue.put(sample_batch, block=True, timeout=put_timeout)
                    break
                except Full:
                    continue

            if tokill():
                return

In [None]:
def threaded_cuda_batches(tokill, cuda_batches_queue, batches_queue, poll_timeout=0.1):
    """
    Thread worker that transfers batches to the module-level ``device``.

    Each item taken from ``batches_queue`` is a ``(sample_batch, labels, ids)``
    tuple of tensors; every tensor is moved with ``.to(device)`` and the tuple
    is then placed on ``cuda_batches_queue``. Both queue operations are polled
    with a timeout so the stop flag is honoured even when the input queue is
    empty or the output queue is full.

    tokill (thread_killer): callable returning True once the thread should terminate
    cuda_batches_queue (Queue): output queue receiving the tuples after the device transfer
    batches_queue (Queue): input queue the tuples are read from
    poll_timeout (float): seconds to block on either queue before re-checking ``tokill``
    """
    # Local import keeps the function self-contained.
    from queue import Empty, Full

    while not tokill():
        try:
            sample_batch, labels, ids = batches_queue.get(block=True, timeout=poll_timeout)
        except Empty:
            # Nothing to transfer yet; loop around to re-check the stop flag.
            continue

        # torch.autograd.Variable is a deprecated no-op wrapper around a tensor,
        # so the tensors are moved directly.
        moved = (sample_batch.to(device), labels.to(device), ids.to(device))

        # Offer the transferred batch until it is accepted or a stop is requested.
        while not tokill():
            try:
                cuda_batches_queue.put(moved, block=True, timeout=poll_timeout)
                break
            except Full:
                continue

In [None]:
class Threadsafe_iter:
    """
    Wrapper that lets several threads pull from one iterator/generator.

    Every ``next`` call on the wrapped object is executed while holding a
    lock, so the calls are serialized.
    """
    def __init__(self, it):
        self.lock = threading.Lock()
        self.it = it

    def __iter__(self):
        # The wrapper is its own iterator.
        return self

    def __next__(self):
        # Explicit acquire/release; the lock is released even when the wrapped
        # iterator raises (including StopIteration).
        self.lock.acquire()
        try:
            return next(self.it)
        finally:
            self.lock.release()

def get_objects_id(objects_count):
    """Endless generator yielding 0, 1, ..., objects_count - 1 and then wrapping back to 0."""
    position = 0
    while True:
        yield position
        position += 1
        position %= objects_count

In [None]:
class SkyDataset(Dataset):
    """Endless batch generator over sky snapshots that can be shared by threads.

    Iterating the dataset never terminates. Row positions come from a cyclic,
    lock-protected counter that is shared by every iterator of this object;
    each sample is loaded and preprocessed outside any lock and then appended
    to batch lists that are shared by all iterating threads.

    Every yielded item is a tuple ``(images, labels, ids)``:
      * images -- ``torch.Tensor`` built from the stacked per-sample arrays;
      * labels -- ``torch.LongTensor`` of ``observed_TCC`` values
        (empty when ``target=False``);
      * ids    -- ``torch.Tensor`` of row positions in the frame
        (filled only when ``target=False``).

    A batch is emitted as soon as ``batch_size`` samples have accumulated, and
    a possibly shorter one when the last row position of the frame has been
    processed. After that the frame is reshuffled (training mode only) and the
    cycle continues.
    """

    def __init__(self, 
                 data_frame, 
                 root_dir, 
                 transform=None, 
                 batch_size = 8, 
                 augment = True,
                 seq = iaa.Sequential([iaa.GaussianBlur(sigma=(0, 5))],random_order=True), 
                 train = True, 
                 target = True):
        """
        Args:
            data_frame (DataFrame): annotations; the columns read here are
                'mission', 'observations_dt', 'jpg_filename' and, when
                ``target`` is true, 'observed_TCC'.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform applied to the
                image tensor after resizing.
            batch_size (int, optional): batch size
            augment (bool, optional): apply ``seq`` to every sample.
            seq (imgaug augmenter, optional): augmentation pipeline.
            train (bool, optional): reshuffle the frame at the start of every pass.
            target (bool, optional): the frame carries labels; when false the
                row positions are returned instead so predictions can be
                matched back to rows.
        """
        self.is_train = train
        self.with_target = target
        self.sky_data = data_frame
        self.root_dir = os.path.abspath(root_dir)
        self.transform = transform

        self.batch_size = batch_size

        self.objects_id_generator = Threadsafe_iter(get_objects_id(self.sky_data.shape[0]))

        self.lock = threading.Lock()        # guards the per-pass (re)initialisation
        self.yield_lock = threading.Lock()  # guards the shared batch lists below
        self.init_count = 0
        self.augment = augment
        self.cache = {}

        if self.augment:
            # instantiate augmentations
            self.seq = seq

        # Shared accumulation lists. They are created once here and only ever
        # swapped out under ``yield_lock`` when a batch is emitted, so a sample
        # appended by one thread is never wiped by another thread's re-init.
        self.imgs = []
        self.labels = []
        self.ids = []

        # The image pipeline is identical for every sample; build it once
        # instead of once per image.
        self._to_tensor = torchvision.transforms.Compose([
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize([256, 256]),
        ])

    def __len__(self):
        return self.sky_data.shape[0]

    def shuffle(self):
        """Replace the frame by a shuffled copy with a fresh 0..n-1 index."""
        self.sky_data = shuffle(self.sky_data).reset_index(drop=True)

    def _image_path(self, row):
        """Build the snapshot path for one annotation row."""
        mission = row['mission']
        # Labelled AI49 data keeps its daily folders under an extra 'ai49-' prefix.
        prefix = 'ai49-' if (self.with_target and mission == 'AI49') else ''
        day_dir = prefix + 'snapshots-' + str((row['observations_dt']).date())
        return os.path.join(self.root_dir, mission, 'snapshots', day_dir, row['jpg_filename'])

    def _load_image(self, img_name):
        """Read one image and return it as a numpy array after resize/transform/augmentation."""
        # Per-device masks (<root>/<mission>/masks/mask-id<devID>.png) are not
        # loaded: the original author disabled that step as too slow.
        img = self._to_tensor(plt.imread(img_name))
        if self.transform:
            img = self.transform(img)

        img = img.numpy()

        if self.augment:
            # NOTE(review): ``images=`` is imgaug's batch argument while ``img``
            # is a single channel-first array here -- confirm the augmenter
            # treats this layout as intended.
            img = self.seq(images = img)
        return img

    def _take_batch(self):
        """Turn the accumulated lists into tensors and start fresh lists.

        Must be called with ``yield_lock`` held.
        """
        batch = (torch.Tensor(np.array(self.imgs)),
                 (torch.Tensor(self.labels)).type(torch.LongTensor),
                 torch.Tensor(self.ids))
        self.imgs = []
        self.labels = []
        self.ids = []
        return batch

    def __iter__(self):
        while True:
            with self.lock:
                if (self.init_count == 0):
                    if self.is_train:
                        self.shuffle()
                    self.init_count = 1

            for obj_id in self.objects_id_generator:
                # Fetch the row once: repeated ``iloc`` look-ups are slow and
                # could mix values from before and after a concurrent shuffle.
                row = self.sky_data.iloc[obj_id]
                img = self._load_image(self._image_path(row))
                if self.with_target:
                    label = int(row['observed_TCC'])

                batch = None
                # Concurrent access by multiple threads to the lists below
                with self.yield_lock:
                    self.imgs.append(img)
                    if self.with_target:
                        self.labels.append(label)
                    else:
                        self.ids.append(obj_id)

                    last_of_pass = (obj_id + 1 == len(self))
                    if last_of_pass or len(self.imgs) >= self.batch_size:
                        batch = self._take_batch()

                # Yield outside the lock so a consumer that is slow to accept
                # the batch does not stall every other loading thread.
                if batch is not None:
                    yield batch
                if last_of_pass:
                    break

            # At the end of an epoch we request a re-init; the reshuffle itself
            # happens at the top of the outer loop.
            with self.lock:
                self.init_count = 0

In [None]:
def train_single_epoch(model: torch.nn.Module,
                       optimizer: torch.optim.Optimizer, 
                       loss_function: torch.nn.Module, 
                       STEPS_PER_EPOCH,
                       train_cuda_batches_queue,
                       data_len):
    """Run one training epoch over batches taken from a queue.

    Args:
        model: network to train; switched to training mode here.
        optimizer: optimizer stepping the model parameters.
        loss_function: callable ``loss_function(prediction, target)``.
        STEPS_PER_EPOCH: number of batches to consume from the queue.
        train_cuda_batches_queue: queue of ``(x, y, ids)`` tuples.
        data_len: divisor applied to the summed batch losses.

    Returns:
        The sum of the per-batch losses divided by ``data_len``: a detached
        0-dim tensor, or ``0.0`` when ``STEPS_PER_EPOCH`` is 0.
        NOTE(review): with a mean-reducing loss this is the sum of batch means
        divided by the sample count, not a per-sample mean -- confirm the
        intended normalisation.
    """
    model.train()
    loss_sum = 0

    for _ in tqdm(range(STEPS_PER_EPOCH), total=STEPS_PER_EPOCH):

        x, y, _ids = train_cuda_batches_queue.get(block = True)

        model.zero_grad()
        hyp = model(x)

        loss = loss_function(hyp, y)
        loss.backward()
        optimizer.step()

        # Accumulate a detached value: adding the live loss tensor would chain
        # every step's autograd history onto the running sum for the whole epoch.
        loss_sum += loss.detach()

    return loss_sum/float(data_len)

In [None]:
@torch.no_grad()
def validate_single_epoch(model: torch.nn.Module,
                          loss_function: torch.nn.Module,                          
                          STEPS_PER_EPOCH,
                          test_cuda_batches_queue,
                          data_len):
    """Evaluate the model on batches taken from a queue, without gradients.

    Args:
        model: network to evaluate; switched to eval mode here.
        loss_function: callable ``loss_function(prediction, target)``.
        STEPS_PER_EPOCH: number of batches to consume from the queue.
        test_cuda_batches_queue: queue of ``(x, y, ids)`` tuples.
        data_len: number of samples used as divisor for both metrics.

    Returns:
        dict with
          'loss'     -- summed per-batch losses divided by ``data_len`` (float);
          'accuracy' -- percentage of samples whose arg-max class equals ``y``.
        Both are 0.0 when ``STEPS_PER_EPOCH`` is 0.
    """
    model.eval()
    # Plain Python numbers so the zero-step case needs no special handling.
    loss_sum = 0.0
    correct = 0

    for _ in range(STEPS_PER_EPOCH):

        x, y, _ids = test_cuda_batches_queue.get(block = True)

        hyp = model(x)
        loss_sum += loss_function(hyp, y).item()

        # argmax already lives on the same device as ``hyp``; no extra transfer needed.
        y_pred = hyp.argmax(dim = 1, keepdim = True)
        correct += y_pred.eq(y.view_as(y_pred)).sum().item()

    loss_avr = loss_sum / float(data_len)
    accuracy_avr = 100 * correct / float(data_len)

    return {'loss' : loss_avr, 'accuracy' : accuracy_avr}

In [None]:
def ploting_curves(loss, best_epoch):
    """
    Draw the training and validation loss histories on one figure.

    loss (dict): loss histories stored under the keys 'train' and 'valid'
    best_epoch: x position of the dashed red vertical line marking the best model
    """
    # One curve per split, in a fixed order so the legend stays stable.
    curves = (('train', 'Training set'), ('valid', 'Val set'))
    for split, legend_name in curves:
        plt.plot(loss[split], label = legend_name)

    plt.axvline(best_epoch, color = 'r', ls = '--', label = 'Best model')

    plt.title('Loss evolution')
    plt.xlabel('epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

In [None]:
import math
def _start_data_pipeline(dataset, preprocess_workers, steps_per_epoch):
    """Launch the loader threads for ``dataset`` and return their handles.

    ``preprocess_workers`` feeder threads fill a small queue of CPU batches and
    one extra thread moves those batches to the compute device.
    """
    # max(1, ...) keeps the queue bounded: Queue(maxsize=0) would be unlimited.
    batches_queue = Queue(maxsize=max(1, min(steps_per_epoch, 3)))
    cuda_batches_queue = Queue(maxsize=4)
    feeder_killer = thread_killer()
    transfer_killer = thread_killer()

    # Daemon threads cannot keep the interpreter alive if shutdown is incomplete.
    threads = [Thread(target=threaded_batches_feeder,
                      args=(feeder_killer, batches_queue, dataset),
                      daemon=True)
               for _ in range(preprocess_workers)]
    threads.append(Thread(target=threaded_cuda_batches,
                          args=(transfer_killer, cuda_batches_queue, batches_queue),
                          daemon=True))
    for thr in threads:
        thr.start()

    return {'cuda_queue': cuda_batches_queue,
            'queues': (cuda_batches_queue, batches_queue),
            'killers': (feeder_killer, transfer_killer),
            'threads': threads}


def _stop_data_pipeline(pipeline, max_rounds=50):
    """Best-effort shutdown: raise the stop flags, then drain the queues.

    Draining frees workers that are blocked on a full queue so they can see
    the flag. The loop ends as soon as every thread has exited or after
    ``max_rounds`` drain rounds.
    """
    for killer in pipeline['killers']:
        killer.set_tokill(True)

    for _ in range(max_rounds):
        if not any(thr.is_alive() for thr in pipeline['threads']):
            break
        for pending in pipeline['queues']:
            try:
                # Enforcing thread shutdown
                pending.get(block=True, timeout=0.1)
            except Empty:
                pass


def train_model(model: torch.nn.Module, 
                train_data,
                test_data,
                loss_function: torch.nn.Module = torch.nn.CrossEntropyLoss(),
                optimizer_class: Type[torch.optim.Optimizer] = torch.optim,
                optimizer_params: Dict = {},
                initial_lr = 0.01,
                lr_scheduler_class: Any = torch.optim.lr_scheduler.ReduceLROnPlateau,
                lr_scheduler_params: Dict = {},
                batch_size = 16,
                s_eq = iaa.Sequential([iaa.GaussianBlur(sigma=(0, 5))],random_order=True),
                max_epochs = 1000,
                early_stopping_patience = 20,
                root_dir = '/content/drive/MyDrive/geo_kaggle',
                checkpoint_paths = ('./best_model.pth',
                                    '/content/drive/MyDrive/geo_kaggle/best_model_disk.pth')):
    """Train ``model`` with threaded data loading, LR scheduling and early stopping.

    Args:
        model: network to train; moved to the module-level ``device``.
        train_data, test_data: annotation frames for the two ``SkyDataset`` objects.
        loss_function: loss used for training and validation.
        optimizer_class: optimizer class to instantiate. Anything that is not
            an ``Optimizer`` subclass (including the default) selects Adam.
        optimizer_params: extra keyword arguments for the optimizer.
        initial_lr: initial learning rate.
        lr_scheduler_class / lr_scheduler_params: scheduler class and its
            keyword arguments.
        batch_size: batch size of both datasets.
        s_eq: augmentation pipeline for the training dataset.
        max_epochs: upper bound on the number of epochs.
        early_stopping_patience: stop once the best validation loss is more
            than this many epochs old.
        root_dir: image root directory handed to both datasets.
        checkpoint_paths: files the best model is saved to with ``torch.save``.

    The loader threads are stopped in a ``finally`` block, so they are also
    shut down when training is interrupted or fails. The loss curves are
    plotted once training ends (early stop or ``max_epochs`` reached).
    """
    SkyData_train = SkyDataset(train_data, root_dir = root_dir, batch_size= batch_size, seq = s_eq)
    SkyData_test = SkyDataset(test_data, root_dir = root_dir, batch_size= batch_size, augment= False, train = False)

    STEPS_PER_EPOCH_TRAIN = math.ceil(len(SkyData_train)/ float(batch_size))
    STEPS_PER_EPOCH_TEST = math.ceil(len(SkyData_test)/ float(batch_size))

    best_val_loss = None
    best_epoch = None
    loss_list = {'train': list(), 'valid': list()}

    pipelines = []
    try:
        # Training pipeline: 24 loader threads; validation pipeline: 8.
        train_pipeline = _start_data_pipeline(SkyData_train, 24, STEPS_PER_EPOCH_TRAIN)
        pipelines.append(train_pipeline)
        test_pipeline = _start_data_pipeline(SkyData_test, 8, STEPS_PER_EPOCH_TEST)
        pipelines.append(test_pipeline)

        # Everything is ready for the training
        model.to(device)

        # Honour ``optimizer_class`` when it really is an optimizer class; the
        # default value is the ``torch.optim`` module, which selects Adam.
        if isinstance(optimizer_class, type) and issubclass(optimizer_class, torch.optim.Optimizer):
            chosen_optimizer = optimizer_class
        else:
            chosen_optimizer = torch.optim.Adam
        optimizer = chosen_optimizer(model.parameters(), lr=initial_lr, **optimizer_params)
        lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_params)

        for epoch in range(max_epochs):

            print(f'Epoch {epoch}')

            train_loss = train_single_epoch(model, optimizer, loss_function,
                                            STEPS_PER_EPOCH_TRAIN,
                                            train_pipeline['cuda_queue'],
                                            len(SkyData_train))
            # Record the training loss so the 'Training set' curve is not empty.
            loss_list['train'].append(float(train_loss))

            print('Validating epoch\n')

            val_metrics = validate_single_epoch(model, loss_function,
                                                STEPS_PER_EPOCH_TEST,
                                                test_pipeline['cuda_queue'],
                                                len(SkyData_test))
            loss_list['valid'].append(val_metrics['loss'])
            print(f'Validation metrics: \n{val_metrics}')

            # Only ReduceLROnPlateau consumes the monitored metric.
            if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                lr_scheduler.step(val_metrics['loss'])
            else:
                lr_scheduler.step()

            if best_val_loss is None or best_val_loss > val_metrics['loss']:
                print(f'Best model yet, saving')
                best_val_loss = val_metrics['loss']
                best_epoch = epoch
                for path in checkpoint_paths:
                    torch.save(model, path)

            if epoch - best_epoch > early_stopping_patience:
                print('Early stopping triggered')
                break
    finally:
        for pipeline in pipelines:
            _stop_data_pipeline(pipeline)

    # best_epoch stays None only when no epoch was run (max_epochs == 0).
    if best_epoch is not None:
        ploting_curves(loss_list, best_epoch)

In [None]:

def GetTarget(data, model, batch = 100, root_dir = '/content/drive/MyDrive/geo_kaggle_test'):
    """Predict the TCC class for the rows of ``data`` and return them as a table.

    Args:
        data (DataFrame): annotations of the images to classify; must contain
            the columns ``SkyDataset`` reads plus 'jpg_filename'.
        model: trained network; switched to eval mode here.
        batch (int): batch size used for inference.
        root_dir (str): image root directory handed to ``SkyDataset``.

    Returns:
        DataFrame with the columns 'jpg_filename' and 'TCC', duplicates removed.

    Raises:
        ValueError: if a batch arrives with a different number of row ids than
            predictions, since filenames and predictions could then not be paired.
    """
    SkyData_test = SkyDataset(data, root_dir = root_dir, batch_size= batch, augment= False, train = False, target = False)
    STEPS_PER_EPOCH_TEST = math.ceil(len(SkyData_test)/ float(batch))

    # Loader set-up. max(1, ...) keeps the queue bounded: maxsize=0 would be unlimited.
    test_batches_queue = Queue(maxsize=max(1, min(STEPS_PER_EPOCH_TEST, 4)))
    test_cuda_batches_queue = Queue(maxsize=4)
    test_thread_killer = thread_killer()
    test_cuda_transfers_thread_killer = thread_killer()
    test_preprocess_workers = 32

    # Daemon threads cannot keep the interpreter alive if shutdown is incomplete.
    workers = [Thread(target=threaded_batches_feeder,
                      args=(test_thread_killer, test_batches_queue, SkyData_test),
                      daemon=True)
               for _ in range(test_preprocess_workers)]
    workers.append(Thread(target=threaded_cuda_batches,
                          args=(test_cuda_transfers_thread_killer, test_cuda_batches_queue, test_batches_queue),
                          daemon=True))
    for thr in workers:
        thr.start()

    filenames = data['jpg_filename']
    # Collect plain records and build the frame once; growing a DataFrame row
    # by row is quadratic.
    records = []

    model.eval()
    try:
        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            for _ in tqdm(range(STEPS_PER_EPOCH_TEST), total=STEPS_PER_EPOCH_TEST):

                x, _labels, ids = test_cuda_batches_queue.get(block = True)

                y_pred = model(x).argmax(dim = 1).tolist()
                ids = ids.tolist()
                if len(ids) != len(y_pred):
                    raise ValueError(
                        f'batch carries {len(ids)} row ids but {len(y_pred)} predictions')

                for obj_id, tcc in zip(ids, y_pred):
                    records.append((filenames.iloc[int(obj_id)], tcc))
    finally:
        # Best-effort shutdown, also reached when inference fails or is interrupted.
        test_thread_killer.set_tokill(True)
        test_cuda_transfers_thread_killer.set_tokill(True)
        for _ in range(50):
            if not any(thr.is_alive() for thr in workers):
                break
            for pending in (test_cuda_batches_queue, test_batches_queue):
                try:
                    # Enforcing thread shutdown
                    pending.get(block=True, timeout=0.1)
                except Empty:
                    pass

    return pd.DataFrame(records, columns=['jpg_filename','TCC']).drop_duplicates()
         
    
        

    

In [None]:
class Pe(torch.nn.Module):
    """Small convolutional classifier: three strided convolutions and one linear head.

    The layer geometry is fixed: the convolutions take 3 input channels and the
    linear head expects 7*7*3 features, which is what a 256x256 input produces
    (256 -> 225 -> 49 -> 7 per side). ``num_classes`` sets the size of the
    output layer. ``input_resolution``, ``input_channels``,
    ``hidden_layer_features`` and ``activation`` are stored on the instance but
    do not influence the layers.
    """

    def __init__(self, 
                 input_resolution: Tuple[int, int] = (512, 512),
                 input_channels: int = 1, 
                 hidden_layer_features: List[int] = [256, 256, 256],
                 activation: Type[torch.nn.Module] = torch.nn.ReLU,
                 num_classes: int = 9):

        super().__init__()

        self.input_resolution = input_resolution
        self.input_channels = input_channels
        self.hidden_layer_features = hidden_layer_features
        self.activation = activation
        self.num_classes = num_classes

        self.conv1 = torch.nn.Conv2d(3, 3, 32)       # kernel 32, stride 1
        self.conv2 = torch.nn.Conv2d(3, 3, 32, 4)    # kernel 32, stride 4
        self.conv3 = torch.nn.Conv2d(3, 3, 7, 7)     # kernel 7, stride 7

        # The head size follows ``num_classes`` (default 9, as before).
        self.fc1 = torch.nn.Linear(7*7*3, num_classes)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, num_classes)."""
        x = self.conv1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = F.relu(x)

        x = self.conv3(x)
        x = F.relu(x)

        # Flatten per sample. Keeping the batch axis fixed makes an input of the
        # wrong size fail in ``fc1`` instead of being silently folded into the
        # batch dimension.
        x = x.view(x.size(0), -1)
        x = self.fc1(x)

        output = F.log_softmax(x, dim = 1)

        return output

In [None]:
# Build the network with its default arguments, move it to the selected
# device and print its layer structure.
model = Pe()
model.to(device)
print(model)
# Count only parameters that require gradients.
print('Total number of trainable parameters', 
      sum(p.numel() for p in model.parameters() if p.requires_grad))

In [None]:
# Turn the loaded index into a DataFrame. If a smarter sampling strategy is
# wanted, this frame is the place to apply it.
data = pd.DataFrame(data_index)
data.shape

In [None]:
# Subsample the rows with observed_TCC == 8 (presumably the dominant class --
# see the histogram cell). The draw is capped at the number of available rows
# so it cannot fail on a smaller index, and seeded so that re-running the
# notebook selects the same rows.
df = data[data['observed_TCC'] == 8 ]
df = df.sample(n=min(10000, len(df)), random_state=42)

In [None]:
# Rebuild the working frame: the class-8 subsample `df` followed by every row
# of the other classes, with a fresh 0..n-1 index.
data = pd.concat([df, data[data['observed_TCC'] != 8 ]], ignore_index=True).reset_index(drop = True)
data.shape

In [None]:
# Stacked histogram of observed_TCC, coloured by mission, on a 10x10 inch figure.
plt.figure(figsize = (10, 10))
sns.histplot(data, x = 'observed_TCC', hue = 'mission', multiple = 'stack')

In [None]:
# Shuffle the rows before splitting. The fixed seed makes the train/validation
# split below reproducible across kernel restarts.
data = shuffle(data, random_state=42).reset_index(drop = True)

In [None]:
# Hold out the rows labelled 60000..67999 for validation and train on the rows
# labelled 0..59999 (.loc slices include both end points).
# NOTE(review): rows labelled 68000 and above end up in neither split --
# confirm that this is intended.
test_data = data.loc[60000:67999,:].reset_index(drop=True)
train_data = data.loc[:59999,:].reset_index(drop=True)

In [None]:
# Start training: cross-entropy loss, initial learning rate 1e-4, batches of 100.
# All other train_model arguments keep their defaults.
train_model(model, 
            train_data,
            test_data,
            loss_function=torch.nn.CrossEntropyLoss(), 
            initial_lr=0.0001,
            batch_size= 100)


In [None]:
#the_model = torch.load('/content/drive/MyDrive/geo_kaggle/best_model_1.pth')

In [None]:
#df = GetTarget(data,the_model)

In [None]:
#df

In [None]:
#df.to_csv('second_try.csv', sep=',' , index=False, header = None)

Из-за параллельной загрузки данных некоторые объекты не попадают в модель; ниже выводятся имена их файлов:

In [None]:
#pd.concat([d, f]).drop_duplicates(keep=False)