In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import torch
import random
from torchvision import datasets, transforms

In [2]:
# Genetic-algorithm hyperparameters, expressed as fractions of the population size.
FIT_SURVIVAL_RATE = 0.5
UNFIT_SURVIVAL_RATE = 0.2
MUTATION_RATE = 0.1

# Seed the RNGs the search uses (gene sampling via `random`, weight init / shuffling via torch)
# so that Restart & Run All reproduces the same architectures.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

DATA_ROOT = './data'
to_tensor = transforms.ToTensor()
trainset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=to_tensor)
kwargs = {'batch_size': 64, 'shuffle': True, 'num_workers': 2, 'pin_memory': True}
train_loader = torch.utils.data.DataLoader(trainset, **kwargs)
testset = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=to_tensor)
# Accuracy does not depend on evaluation order, so the test loader does not shuffle.
test_loader = torch.utils.data.DataLoader(testset, **{**kwargs, 'shuffle': False})
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import torch
import random

# Gene name -> allowed values. The suffixes _a / _b refer to the first conv layer and the
# optional second conv layer a chromosome adds. include_b toggles that second layer.
# skip_connection adds a 1x1-conv shortcut around the new layers, and pooling is not used when it is set.
SEARCH_SPACE = {
    'k_size_a': [1, 3, 5],
    'k_size_b': [1, 3, 5],
    'out_channels_a': [8, 16, 32, 64],
    'out_channels_b': [8, 16, 32, 64],
    'include_pool_a': [True, False],
    'include_pool_b': [True, False],
    'pool_type_a': ['max_pooling', 'avg_pooling'],
    'pool_type_b': ['max_pooling', 'avg_pooling'],
    'activation_type_a': ['relu', 'tanh', 'elu', 'selu'],
    'activation_type_b': ['relu', 'tanh', 'elu', 'selu'],
    'include_b': [True, False],
    'include_BN_a': [True, False],
    'include_BN_b': [True, False],
    'skip_connection': [True, False],
}

# Search space for the initial population. It is identical to SEARCH_SPACE except that
# include_b only allows False. It is derived (with copied lists) so the two cannot drift apart.
# NOTE(review): Generation.make_gene forces include_b=True in phase 0, which overrides
# this restriction. Confirm which behaviour is intended.
INIT_SEARCH_SPACE = {gene: list(values) for gene, values in SEARCH_SPACE.items()}
INIT_SEARCH_SPACE['include_b'] = [False]

Files already downloaded and verified
Files already downloaded and verified


In [3]:
class FinalModel(nn.Module):
    """Trainable classifier built around a Chromosome's feature extractor.

    ``chromosome.model`` is the convolutional body. On top of it this adds an optional
    1x1-conv skip path and a linear head that produces 10-way log-probabilities.

    Args:
        chromosome: the Chromosome whose ``model``, ``genes``, ``phase``,
            ``out_dimensions`` and (for phase > 0) ``prev_best`` describe the network.
    """

    def __init__(self, chromosome):
        super().__init__()
        self.block = chromosome.model
        if chromosome.phase == 0:
            # Phase 0 always builds both conv layers, so the body ends with out_channels_b.
            # The skip path starts from the raw 3-channel input.
            in_channels = 3
            out_channels = chromosome.genes['out_channels_b']
        else:
            # The skip path starts from the output of the previous best body...
            prev_genes = chromosome.prev_best.genes
            in_channels = (prev_genes['out_channels_b'] if prev_genes['include_b']
                           else prev_genes['out_channels_a'])
            # ...and must end with the channel count of this chromosome's last conv layer.
            genes = chromosome.genes
            out_channels = genes['out_channels_b'] if genes['include_b'] else genes['out_channels_a']
        # 1x1 conv projecting the skip input to the body's channel count. It is created even
        # when the skip gene is off; it is then never called in forward().
        self.skip = nn.Conv2d(in_channels, out_channels, 1)

        self.flatten = nn.Flatten()
        self.fc = nn.Linear(out_channels * chromosome.out_dimensions ** 2, 10)

    def forward(self, x, chromosome):
        """Return log-probabilities of shape (batch, 10) for the batch ``x``.

        ``chromosome`` should be the one this model was built from. Its
        ``skip_connection`` gene decides whether the skip path is added.
        """
        if chromosome.genes['skip_connection']:
            if chromosome.phase == 0:
                # Skip from the raw input around the whole body.
                y = self.skip(x)
                x = self.block(x) + y
            else:
                prev_model = chromosome.prev_best.model
                if (isinstance(self.block, nn.Sequential) and len(self.block) > 0
                        and self.block[0] is prev_model):
                    # The body starts with the previous best model (Chromosome.build_model
                    # prepends it). Run it once and feed its output to both the new layers and
                    # the skip path. Evaluating it twice would waste compute and update any
                    # BatchNorm running statistics inside it twice per batch.
                    features = prev_model(x)
                    y = self.skip(features)
                    x = self.block[1:](features) + y
                else:
                    y = self.skip(prev_model(x))
                    x = self.block(x) + y
        else:
            x = self.block(x)
        x = self.fc(self.flatten(x))
        return F.log_softmax(x, dim=1)

In [4]:
class Chromosome:
    """One candidate architecture in the evolutionary search; it is trained and scored on creation.

    A chromosome adds one conv layer (and optionally a second one, always in phase 0)
    described by ``genes``. In phase 0 these layers sit on the raw image; in later phases
    they sit on top of the previous phase's best model. ``fitness`` is the accuracy (%)
    measured on ``test_loader`` after a short training run. ``fitness == 0`` marks gene
    combinations that could not be built.

    Args:
        phase: search phase; phase 0 builds directly on the 3-channel input image.
        prev_best: best Chromosome of the previous phase (``None`` in phase 0).
        genes: mapping from gene name (see SEARCH_SPACE) to the chosen value.
        train_loader: batches used to train the candidate.
        test_loader: batches used to measure its fitness.
    """

    def __init__(self, phase: int, prev_best, genes: dict, train_loader, test_loader):
        self.phase = phase
        self.prev_best = prev_best
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Phase 0 always builds layer b (see build_model), so record that in the genes.
        # Otherwise a later phase would read the wrong channel count from prev_best.genes.
        self.genes = self._normalize_genes(phase, genes)
        # Side length of the (square) feature map entering this chromosome's layers.
        self.out_dimensions = prev_best.out_dimensions if phase != 0 else 32
        self.fitness = -1  # sentinel for "not evaluated"; build_model sets 0 for invalid genes
        self.model: nn.Module = self.build_model()
        self.train_loader = train_loader
        self.test_loader = test_loader
        if self.fitness == -1:
            self.fitness = self.fitness_function(train_loader, test_loader)

    @staticmethod
    def _normalize_genes(phase: int, genes: dict) -> dict:
        """Return ``genes`` with ``include_b`` forced to True in phase 0 (copies only if needed)."""
        if phase == 0 and not genes.get('include_b', False):
            return {**genes, 'include_b': True}
        return genes

    @staticmethod
    def _activation(name: str) -> nn.Module:
        """Map an activation gene to its module; any unrecognised name falls back to Tanh."""
        if name == 'relu':
            return nn.ReLU()
        if name == 'elu':
            return nn.ELU()
        if name == 'selu':
            return nn.SELU()
        return nn.Tanh()

    def _reject(self) -> nn.Module:
        """Mark these genes as unbuildable (fitness 0) and return an empty placeholder model."""
        self.fitness = 0
        return nn.Sequential()

    def _append_pool(self, modules: list, pool_type: str, layer: str) -> bool:
        """Append a 2x2, stride-2 pool to ``modules`` and halve ``self.out_dimensions``.

        Returns False, appending nothing, when the feature map is already smaller than 2.
        No padding is used: pooling is only added when there is no skip connection.

        Raises:
            Exception: for an unknown ``pool_type``.
        """
        if self.out_dimensions < 2:
            return False
        if pool_type == 'max_pooling':
            modules.append(nn.MaxPool2d(2, 2))
        elif pool_type == 'avg_pooling':
            modules.append(nn.AvgPool2d(2, 2))
        else:
            raise Exception(f'Invalid pool type ({layer} layer)')
        self.out_dimensions = self.out_dimensions // 2
        return True

    def build_model(self) -> nn.Module:
        """Build the feature extractor described by ``self.genes``.

        For layer a, and then for layer b when ``include_b`` is set (always in phase 0), this
        appends conv ('same' padding) -> activation -> optional pool -> optional BatchNorm.
        In phases > 0 the result is placed after ``prev_best.model``. Pooling is skipped when
        ``skip_connection`` is set, so the output keeps the spatial size of the skip input.
        Updates ``self.out_dimensions`` to the output side length.

        Returns:
            the ``nn.Sequential`` model. If a kernel or pool does not fit the current feature
            map, returns an empty ``nn.Sequential`` and sets ``self.fitness = 0``.
        """
        genes = self.genes
        skip = genes['skip_connection']
        if self.out_dimensions < genes['k_size_a']:
            return self._reject()

        if self.phase != 0:
            prev_genes = self.prev_best.genes
            in_channels = (prev_genes['out_channels_b'] if prev_genes['include_b']
                           else prev_genes['out_channels_a'])
        else:
            in_channels = 3
        modules = [nn.Conv2d(in_channels, genes['out_channels_a'], genes['k_size_a'], padding='same'),
                   self._activation(genes['activation_type_a'])]
        if genes['include_pool_a'] and not skip:
            if not self._append_pool(modules, genes['pool_type_a'], 'a'):
                return self._reject()
        if genes['include_BN_a']:
            modules.append(nn.BatchNorm2d(genes['out_channels_a']))

        if genes['include_b'] or self.phase == 0:
            if self.out_dimensions < genes['k_size_b']:
                return self._reject()
            modules.append(nn.Conv2d(genes['out_channels_a'], genes['out_channels_b'],
                                     genes['k_size_b'], padding='same'))
            modules.append(self._activation(genes['activation_type_b']))
            if genes['include_pool_b'] and not skip:
                if not self._append_pool(modules, genes['pool_type_b'], 'b'):
                    return self._reject()
            if genes['include_BN_b']:
                modules.append(nn.BatchNorm2d(genes['out_channels_b']))

        if self.phase != 0:
            # NOTE(review): prev_best.model is reused, not copied, so training this chromosome
            # also updates the previous best's weights. Confirm this shared fine-tuning is intended.
            new_model = nn.Sequential(self.prev_best.model, *modules)
        else:
            new_model = nn.Sequential(*modules)
        if skip:
            # No pooling was added, so the output keeps the input's side length.
            self.out_dimensions = 32 if self.phase == 0 else self.prev_best.out_dimensions
        return new_model

    def fitness_function(self, train_loader, test_loader) -> float:
        """Train a FinalModel around this chromosome and return its accuracy in percent.

        Trains for 5 epochs with Adam (lr=0.001) and NLL loss on ``train_loader``. After every
        epoch it reports accuracy on ``test_loader``; the value from the last epoch is returned.
        """
        # NOTE(review): the fitness used for selection is measured on the test split, so the
        # search effectively tunes on test data. Consider a held-out validation split.
        new_model = FinalModel(self)
        optimizer = optim.Adam(new_model.parameters(), lr=0.001)
        criterion = F.nll_loss
        new_model.to(self.device)
        num_epochs = 5
        accuracy = 0.0
        for epoch in range(num_epochs):
            pbar = tqdm(train_loader)
            new_model.train()
            for batch_idx, (data, target) in enumerate(pbar):
                data, target = data.to(self.device), target.to(self.device)
                optimizer.zero_grad()
                output = new_model(x=data, chromosome=self)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                pbar.set_description(desc=f'epoch {epoch} loss={loss.item()} batch_id={batch_idx}')

            correct = 0
            total = 0
            new_model.eval()
            with torch.no_grad():
                for data in test_loader:
                    images, labels = data[0].to(self.device), data[1].to(self.device)
                    outputs = new_model(images, self)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            accuracy = 100 * correct / total
            print("Validation accuracy: {}".format(accuracy))
        print(f"Fitness calculated: {accuracy}")
        return accuracy

    def crossover(self, chromosome):
        """Return a new (trained) Chromosome whose genes are each picked at random from one parent."""
        new_genes = {key: random.choice([self.genes[key], chromosome.genes[key]])
                     for key in self.genes}
        return Chromosome(self.phase, self.prev_best, new_genes, self.train_loader, self.test_loader)

    def _mutated_genes(self, search_space: dict) -> dict:
        """Return a copy of the genes with one randomly chosen gene set to a different allowed value.

        ``include_b`` is not mutated in phase 0, because that phase always builds layer b.
        """
        candidates = [key for key in self.genes if not (self.phase == 0 and key == 'include_b')]
        mutated_gene = random.choice(candidates)
        possible_values = list(search_space[mutated_gene])
        possible_values.remove(self.genes[mutated_gene])
        new_genes = self.genes.copy()
        new_genes[mutated_gene] = random.choice(possible_values)
        return new_genes

    def mutation(self, search_space=None):
        """Return a new (trained) Chromosome with one gene changed.

        Args:
            search_space: allowed values per gene; defaults to the global SEARCH_SPACE.
        """
        if search_space is None:
            search_space = SEARCH_SPACE
        new_genes = self._mutated_genes(search_space)
        return Chromosome(self.phase, self.prev_best, new_genes, self.train_loader, self.test_loader)

In [5]:
class Generation():
    """A population of Chromosomes for one search phase, evolved by selection, mutation and crossover.

    Args:
        fit_survival_rate: fraction of ``pop_size`` kept from the top of the fitness ranking.
        unfit_survival_rate: fraction of ``pop_size`` kept from the bottom, for diversity.
        mutation_rate: fraction of ``pop_size`` of the survivors that are replaced by a mutant.
        pop_size: target number of chromosomes per generation.
        phase: search phase passed to every Chromosome.
        search_space: gene name -> list of allowed values, sampled for the initial population.
        prev_best: best Chromosome of the previous phase (``None`` in phase 0).
        train_loader: training batches passed to every Chromosome.
        test_loader: scoring batches passed to every Chromosome.
    """

    def __init__(self,
                 fit_survival_rate: float,
                 unfit_survival_rate: float,
                 mutation_rate: float,
                 pop_size: int,
                 phase: int,
                 search_space: dict,
                 prev_best: 'Chromosome',
                 train_loader,
                 test_loader):
        self.fit_survival_rate = fit_survival_rate
        self.unfit_survival_rate = unfit_survival_rate
        self.mutation_rate = mutation_rate
        self.pop_size = pop_size
        self.phase = phase
        # Each Chromosome trains itself when constructed, so this builds pop_size trained models.
        self.pop = [Chromosome(phase=phase,
                               prev_best=prev_best,
                               genes=self.make_gene(search_space),
                               train_loader=train_loader,
                               test_loader=test_loader)
                    for _ in range(pop_size)]

    def make_gene(self, search_space: dict):
        """Sample one random value per gene; phase 0 always enables layer b."""
        gene = {key: random.choice(values) for key, values in search_space.items()}
        if self.phase == 0:
            gene['include_b'] = True
        return gene

    def sort_pop(self):
        """Order ``self.pop`` by fitness, best first."""
        self.pop = sorted(self.pop, key=lambda chrom: chrom.fitness, reverse=True)

    def generate(self):
        """Replace ``self.pop`` with the next generation, sorted best-first.

        Keeps the top ``fit_survival_rate`` and bottom ``unfit_survival_rate`` fractions,
        dropping fitness-0 chromosomes. Mutates ``mutation_rate`` of them, then refills the
        population up to ``pop_size`` with crossovers of survivors that built successfully.
        If fewer than two such parents exist, no crossover is possible and the generation
        stays smaller; ``pop_size`` keeps the target so later generations try to refill.
        """
        self.sort_pop()
        current = len(self.pop)
        num_fit_selected = min(int(self.fit_survival_rate * self.pop_size), current)
        # Clamp so the bottom slice never overlaps the top one (no chromosome is kept twice).
        num_unfit_selected = min(int(self.unfit_survival_rate * self.pop_size),
                                 current - num_fit_selected)
        num_mutate = int(self.mutation_rate * self.pop_size)

        selected = self.pop[:num_fit_selected] + [self.pop[current - i - 1]
                                                   for i in range(num_unfit_selected)]
        new_pop = [chrom for chrom in selected if chrom.fitness != 0]

        for i in random.sample(range(len(new_pop)), min(num_mutate, len(new_pop))):
            new_pop[i] = new_pop[i].mutation()

        # Breed only from parents that built successfully (a mutant may have fitness 0).
        # Children are not reused as parents within this generation.
        parents = [chrom for chrom in new_pop if chrom.fitness != 0]
        while len(new_pop) < self.pop_size and len(parents) >= 2:
            parent1, parent2 = random.sample(parents, 2)
            new_pop.append(parent1.crossover(parent2))

        self.pop = new_pop
        self.sort_pop()
        print("\n\n")

    def find_fittest(self):
        """Sort the population and return its highest-fitness chromosome."""
        self.sort_pop()
        return self.pop[0]

In [6]:
# Phase-0 population. Every Chromosome trains itself on construction, so this cell runs
# num_individuals full training jobs. Generation.make_gene sets include_b=True in phase 0,
# so these initial architectures always contain both conv layers, whatever INIT_SEARCH_SPACE says.
num_individuals = 15
generation = Generation(
    pop_size=num_individuals,
    phase=0,
    prev_best=None,
    search_space=INIT_SEARCH_SPACE,
    fit_survival_rate=FIT_SURVIVAL_RATE,
    unfit_survival_rate=UNFIT_SURVIVAL_RATE,
    mutation_rate=MUTATION_RATE,
    train_loader=train_loader,
    test_loader=test_loader,
)

epoch 0 loss=1.2595139741897583 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 219.38it/s]


Validation accuracy: 47.62


epoch 1 loss=1.2559446096420288 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 234.81it/s]


Validation accuracy: 57.4


epoch 2 loss=1.1645759344100952 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 191.47it/s]


Validation accuracy: 61.1


epoch 3 loss=1.4223006963729858 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 242.85it/s] 


Validation accuracy: 60.47


epoch 4 loss=0.996330976486206 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 219.73it/s]  


Validation accuracy: 61.45
Fitness calculated: 61.45


epoch 0 loss=1.1566401720046997 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 167.48it/s]


Validation accuracy: 41.28


epoch 1 loss=1.0006740093231201 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 170.30it/s]


Validation accuracy: 44.6


epoch 2 loss=1.3026589155197144 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 179.25it/s]


Validation accuracy: 45.65


epoch 3 loss=1.696632742881775 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 177.38it/s] 


Validation accuracy: 47.58


epoch 4 loss=2.001685619354248 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 176.79it/s] 


Validation accuracy: 50.45
Fitness calculated: 50.45


epoch 0 loss=1.2375341653823853 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 310.71it/s]


Validation accuracy: 56.17


epoch 1 loss=0.8809983730316162 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 306.16it/s]


Validation accuracy: 59.91


epoch 2 loss=0.6760192513465881 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 297.33it/s]


Validation accuracy: 60.49


epoch 3 loss=1.1407761573791504 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 317.99it/s]


Validation accuracy: 61.55


epoch 4 loss=0.9343262910842896 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 293.30it/s]


Validation accuracy: 62.27
Fitness calculated: 62.27


epoch 0 loss=1.7004363536834717 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 341.56it/s]


Validation accuracy: 41.49


epoch 1 loss=1.214741826057434 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 343.36it/s] 


Validation accuracy: 40.67


epoch 2 loss=1.1553350687026978 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 334.45it/s]


Validation accuracy: 44.47


epoch 3 loss=1.329593300819397 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 336.89it/s] 


Validation accuracy: 46.33


epoch 4 loss=1.1514365673065186 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 342.81it/s]


Validation accuracy: 47.33
Fitness calculated: 47.33


epoch 0 loss=1.3150838613510132 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 302.40it/s]


Validation accuracy: 38.38


epoch 1 loss=2.3550961017608643 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 301.78it/s]


Validation accuracy: 42.91


epoch 2 loss=1.852813959121704 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 316.48it/s] 


Validation accuracy: 42.58


epoch 3 loss=1.4081003665924072 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 310.93it/s]


Validation accuracy: 42.7


epoch 4 loss=1.3246389627456665 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 311.09it/s]


Validation accuracy: 43.43
Fitness calculated: 43.43


epoch 0 loss=2.6689610481262207 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 360.91it/s]


Validation accuracy: 41.27


epoch 1 loss=1.5864508152008057 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 360.12it/s]


Validation accuracy: 42.4


epoch 2 loss=1.9743738174438477 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 348.80it/s]


Validation accuracy: 42.08


epoch 3 loss=1.6497232913970947 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 371.19it/s]


Validation accuracy: 44.52


epoch 4 loss=1.2384419441223145 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 347.46it/s]


Validation accuracy: 46.18
Fitness calculated: 46.18


epoch 0 loss=1.4454225301742554 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 281.97it/s]


Validation accuracy: 48.5


epoch 1 loss=1.0213724374771118 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 283.73it/s]


Validation accuracy: 52.3


epoch 2 loss=1.389613389968872 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 286.31it/s] 


Validation accuracy: 55.49


epoch 3 loss=1.6104586124420166 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 292.20it/s]


Validation accuracy: 56.6


epoch 4 loss=1.2558165788650513 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 286.68it/s]


Validation accuracy: 57.92
Fitness calculated: 57.92


epoch 0 loss=1.1356487274169922 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 342.35it/s]


Validation accuracy: 52.05


epoch 1 loss=1.4434226751327515 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 334.90it/s]


Validation accuracy: 56.93


epoch 2 loss=0.8603779077529907 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 336.24it/s]


Validation accuracy: 58.58


epoch 3 loss=1.897050380706787 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 342.28it/s] 


Validation accuracy: 61.46


epoch 4 loss=1.5853761434555054 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 348.52it/s]


Validation accuracy: 61.11
Fitness calculated: 61.11


epoch 0 loss=1.803973913192749 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 355.24it/s] 


Validation accuracy: 42.14


epoch 1 loss=1.4398853778839111 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 359.60it/s]


Validation accuracy: 47.13


epoch 2 loss=1.3214411735534668 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 364.73it/s]


Validation accuracy: 49.46


epoch 3 loss=1.1343598365783691 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 337.55it/s]


Validation accuracy: 52.69


epoch 4 loss=1.297728180885315 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 329.92it/s] 


Validation accuracy: 53.79
Fitness calculated: 53.79


epoch 0 loss=3.763913154602051 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 219.22it/s] 


Validation accuracy: 35.11


epoch 1 loss=2.622899055480957 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 217.41it/s] 


Validation accuracy: 43.5


epoch 2 loss=1.4694297313690186 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 215.94it/s]


Validation accuracy: 47.01


epoch 3 loss=1.6593499183654785 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 217.62it/s]


Validation accuracy: 46.18


epoch 4 loss=1.3912980556488037 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 210.30it/s]


Validation accuracy: 48.13
Fitness calculated: 48.13


epoch 0 loss=1.4813742637634277 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 234.57it/s]


Validation accuracy: 46.43


epoch 1 loss=1.7546836137771606 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 233.94it/s]


Validation accuracy: 51.54


epoch 2 loss=1.15911865234375 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 233.81it/s]  


Validation accuracy: 51.57


epoch 3 loss=0.9846889972686768 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 231.67it/s]


Validation accuracy: 54.96


epoch 4 loss=0.7741592526435852 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 229.92it/s]


Validation accuracy: 54.56
Fitness calculated: 54.56


epoch 0 loss=1.4245223999023438 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 297.48it/s]


Validation accuracy: 41.22


epoch 1 loss=1.9063129425048828 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 309.96it/s]


Validation accuracy: 43.16


epoch 2 loss=1.1882703304290771 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 308.81it/s]


Validation accuracy: 42.76


epoch 3 loss=1.5015416145324707 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 301.71it/s]


Validation accuracy: 44.2


epoch 4 loss=1.8330367803573608 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 307.02it/s]


Validation accuracy: 44.84
Fitness calculated: 44.84


epoch 0 loss=0.9399345517158508 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 339.31it/s]


Validation accuracy: 57.41


epoch 1 loss=0.6611720323562622 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 351.41it/s]


Validation accuracy: 59.69


epoch 2 loss=1.3311808109283447 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 339.35it/s]


Validation accuracy: 61.8


epoch 3 loss=0.8606046438217163 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 342.62it/s]


Validation accuracy: 62.33


epoch 4 loss=1.0482661724090576 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 347.44it/s]


Validation accuracy: 62.37
Fitness calculated: 62.37


epoch 0 loss=1.966251015663147 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 362.88it/s] 


Validation accuracy: 42.25


epoch 1 loss=1.4199765920639038 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 359.74it/s]


Validation accuracy: 44.12


epoch 2 loss=1.4777016639709473 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 360.27it/s]


Validation accuracy: 46.11


epoch 3 loss=1.6937330961227417 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 365.41it/s]


Validation accuracy: 45.95


epoch 4 loss=1.7402395009994507 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 339.17it/s]


Validation accuracy: 47.81
Fitness calculated: 47.81


epoch 0 loss=1.3866429328918457 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 355.73it/s]


Validation accuracy: 45.66


epoch 1 loss=1.1366267204284668 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 356.67it/s]


Validation accuracy: 48.42


epoch 2 loss=1.6133462190628052 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 338.68it/s]


Validation accuracy: 50.43


epoch 3 loss=1.044173240661621 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 340.02it/s] 


Validation accuracy: 52.58


epoch 4 loss=0.9194194078445435 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 352.28it/s]


Validation accuracy: 52.49
Fitness calculated: 52.49


In [7]:
# Tournament rounds: draw two random members and replace the weaker one with a mutant of the
# stronger one. Ties go to the second draw. Each mutation() trains a new Chromosome.
rounds = 15

for i in range(rounds):
    index1, index2 = random.sample(range(0, len(generation.pop)), 2)
    if generation.pop[index1].fitness > generation.pop[index2].fitness:
        winner, loser = index1, index2
    else:
        winner, loser = index2, index1
    # Reproduce the winner before removing the loser. Appending at the end leaves the
    # loser's index unchanged, so pop(loser) still removes the right member.
    generation.pop.append(generation.pop[winner].mutation())
    generation.pop.pop(loser)


epoch 0 loss=1.1069512367248535 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 369.95it/s]


Validation accuracy: 59.93


epoch 1 loss=1.2593777179718018 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 418.72it/s]


Validation accuracy: 61.38


epoch 2 loss=0.9588983654975891 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 441.02it/s]


Validation accuracy: 63.97


epoch 3 loss=0.5843693017959595 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 405.98it/s] 


Validation accuracy: 64.83


epoch 4 loss=1.0895774364471436 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 423.95it/s]


Validation accuracy: 62.11
Fitness calculated: 62.11


epoch 0 loss=1.1704005002975464 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 309.98it/s]


Validation accuracy: 49.71


epoch 1 loss=0.6819384098052979 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 311.52it/s]


Validation accuracy: 54.33


epoch 2 loss=1.705175518989563 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 307.30it/s] 


Validation accuracy: 57.28


epoch 3 loss=1.1824856996536255 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 306.53it/s]


Validation accuracy: 58.22


epoch 4 loss=0.7209721207618713 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 307.99it/s]


Validation accuracy: 56.17
Fitness calculated: 56.17


epoch 0 loss=1.420373558998108 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 427.65it/s] 


Validation accuracy: 53.89


epoch 1 loss=1.1189532279968262 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 422.29it/s]


Validation accuracy: 58.49


epoch 2 loss=0.9697616696357727 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 428.21it/s]


Validation accuracy: 60.68


epoch 3 loss=1.022669792175293 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 437.22it/s] 


Validation accuracy: 60.07


epoch 4 loss=1.178345799446106 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 414.20it/s] 


Validation accuracy: 61.54
Fitness calculated: 61.54


epoch 0 loss=1.601891279220581 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 173.54it/s] 


Validation accuracy: 48.22


epoch 1 loss=1.247334599494934 batch_id=781: 100%|██████████| 782/782 [00:05<00:00, 155.48it/s] 


Validation accuracy: 55.62


epoch 2 loss=0.7441659569740295 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 193.65it/s]


Validation accuracy: 56.24


epoch 3 loss=1.2688617706298828 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 195.62it/s]


Validation accuracy: 60.01


epoch 4 loss=0.8506298661231995 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 185.26it/s]


Validation accuracy: 58.8
Fitness calculated: 58.8


epoch 0 loss=1.2776345014572144 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 416.21it/s]


Validation accuracy: 55.67


epoch 1 loss=1.2006008625030518 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 439.89it/s]


Validation accuracy: 59.3


epoch 2 loss=0.9568990468978882 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 430.79it/s]


Validation accuracy: 58.11


epoch 3 loss=0.8751956224441528 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 430.16it/s]


Validation accuracy: 60.88


epoch 4 loss=0.6408754587173462 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 422.06it/s]


Validation accuracy: 59.48
Fitness calculated: 59.48


epoch 0 loss=0.8783440589904785 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 426.05it/s]


Validation accuracy: 60.75


epoch 1 loss=1.219024896621704 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 411.97it/s] 


Validation accuracy: 63.2


epoch 2 loss=0.9646499752998352 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 337.66it/s]


Validation accuracy: 63.31


epoch 3 loss=0.9458885192871094 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 296.45it/s]


Validation accuracy: 64.51


epoch 4 loss=1.0692243576049805 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 436.69it/s]


Validation accuracy: 63.8
Fitness calculated: 63.8


epoch 0 loss=1.559086799621582 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 431.64it/s] 


Validation accuracy: 44.74


epoch 1 loss=1.3591022491455078 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 416.35it/s]


Validation accuracy: 50.62


epoch 2 loss=1.378806233406067 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 449.66it/s] 


Validation accuracy: 52.44


epoch 3 loss=0.8597484827041626 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 439.08it/s]


Validation accuracy: 55.15


epoch 4 loss=0.8724141120910645 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 458.03it/s]


Validation accuracy: 55.78
Fitness calculated: 55.78


epoch 0 loss=0.9995945692062378 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 307.52it/s]


Validation accuracy: 51.33


epoch 1 loss=1.3931406736373901 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 305.04it/s]


Validation accuracy: 54.28


epoch 2 loss=1.4256612062454224 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 304.01it/s]


Validation accuracy: 55.0


epoch 3 loss=1.203055739402771 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 171.11it/s] 


Validation accuracy: 57.0


epoch 4 loss=1.5147924423217773 batch_id=781: 100%|██████████| 782/782 [00:05<00:00, 147.90it/s]


Validation accuracy: 55.11
Fitness calculated: 55.11


epoch 0 loss=0.9849746823310852 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 208.53it/s]


Validation accuracy: 53.44


epoch 1 loss=1.0629394054412842 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 362.73it/s]


Validation accuracy: 58.32


epoch 2 loss=1.3175643682479858 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 415.75it/s]


Validation accuracy: 60.78


epoch 3 loss=1.5135397911071777 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 386.62it/s]


Validation accuracy: 61.54


epoch 4 loss=0.9927365183830261 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 409.83it/s]


Validation accuracy: 61.69
Fitness calculated: 61.69


epoch 0 loss=1.0375068187713623 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 188.15it/s]


Validation accuracy: 50.98


epoch 1 loss=0.6979513764381409 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 183.78it/s]


Validation accuracy: 53.19


epoch 2 loss=1.2043588161468506 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 185.63it/s]


Validation accuracy: 54.14


epoch 3 loss=0.9622467756271362 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 182.21it/s]


Validation accuracy: 56.35


epoch 4 loss=0.9642923474311829 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 183.49it/s] 


Validation accuracy: 58.35
Fitness calculated: 58.35


epoch 0 loss=1.4336144924163818 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 332.01it/s]


Validation accuracy: 48.37


epoch 1 loss=1.4711706638336182 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 419.37it/s]


Validation accuracy: 54.13


epoch 2 loss=1.4368035793304443 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 407.10it/s]


Validation accuracy: 57.25


epoch 3 loss=1.2119985818862915 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 437.68it/s]


Validation accuracy: 58.88


epoch 4 loss=1.0043776035308838 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 420.40it/s]


Validation accuracy: 58.75
Fitness calculated: 58.75


epoch 0 loss=1.4095922708511353 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 259.04it/s]


Validation accuracy: 55.67


epoch 1 loss=1.953348994255066 batch_id=781: 100%|██████████| 782/782 [00:06<00:00, 119.14it/s] 


Validation accuracy: 61.65


epoch 2 loss=1.2194101810455322 batch_id=781: 100%|██████████| 782/782 [00:05<00:00, 148.04it/s]


Validation accuracy: 60.35


epoch 3 loss=1.0413413047790527 batch_id=781: 100%|██████████| 782/782 [00:05<00:00, 149.53it/s]


Validation accuracy: 63.56


epoch 4 loss=0.5213300585746765 batch_id=781: 100%|██████████| 782/782 [00:05<00:00, 152.20it/s] 


Validation accuracy: 62.61
Fitness calculated: 62.61


epoch 0 loss=1.3248752355575562 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 162.98it/s]


Validation accuracy: 56.97


epoch 1 loss=1.0811463594436646 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 271.05it/s]


Validation accuracy: 58.31


epoch 2 loss=1.3665796518325806 batch_id=781: 100%|██████████| 782/782 [00:01<00:00, 405.46it/s]


Validation accuracy: 58.5


epoch 3 loss=1.1568580865859985 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 380.22it/s]


Validation accuracy: 58.24


epoch 4 loss=0.39578527212142944 batch_id=781: 100%|██████████| 782/782 [00:02<00:00, 325.46it/s]


Validation accuracy: 59.34
Fitness calculated: 59.34


epoch 0 loss=1.2274208068847656 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 208.21it/s]


Validation accuracy: 54.53


epoch 1 loss=1.2133371829986572 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 209.30it/s]


Validation accuracy: 60.25


epoch 2 loss=1.0830453634262085 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 234.53it/s]


Validation accuracy: 60.01


epoch 3 loss=0.5943021178245544 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 214.91it/s]


Validation accuracy: 61.12


epoch 4 loss=1.6952224969863892 batch_id=781: 100%|██████████| 782/782 [00:03<00:00, 218.31it/s] 


Validation accuracy: 56.79
Fitness calculated: 56.79


epoch 0 loss=1.4252705574035645 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 181.81it/s]


Validation accuracy: 53.06


epoch 1 loss=1.1347570419311523 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 187.65it/s]


Validation accuracy: 56.84


epoch 2 loss=1.509361743927002 batch_id=781: 100%|██████████| 782/782 [00:04<00:00, 179.61it/s] 


Validation accuracy: 57.25


epoch 3 loss=1.080807089805603 batch_id=781: 100%|██████████| 782/782 [00:06<00:00, 120.41it/s] 


Validation accuracy: 56.88


epoch 4 loss=0.5793830752372742 batch_id=781: 100%|██████████| 782/782 [00:08<00:00, 93.16it/s] 


Validation accuracy: 59.2
Fitness calculated: 59.2


In [9]:
# Rank the current population, then display the fitness of its best individual.
# In the logs above, each "Fitness calculated" value equals that model's
# last-epoch validation accuracy.
generation.sort_pop()
fittest_individual = generation.find_fittest()
fittest_individual.fitness

63.8