# In [2]:
from abc import ABC, abstractmethod

class RetrainFramework(ABC):
    """Abstract interface for an augment / retrain / evaluate workflow.

    Every method below is abstract, ``__init__`` included, so a concrete
    subclass must override all four before it can be instantiated.
    """

    @abstractmethod
    def __init__(self):
        """Set up framework state; concrete subclasses must override this."""

    @abstractmethod
    def getAugmentedData(self, data):
        """Return the training data to use, derived from *data*.

        Presumably augmented in some way; the contract is left to subclasses.
        """

    @abstractmethod
    def retrainModel(self, model, data):
        """Retrain *model* on *data*."""

    @abstractmethod
    def evaluateModel(self, model, data):
        """Evaluate *model* on *data* and return whatever the subclass reports."""


# In [None]:
class NaiveFramework(RetrainFramework):
    """Baseline framework: no augmentation, plain ``fit`` / ``evaluate`` calls.

    Assumes ``model`` exposes ``fit(data)`` and ``evaluate(data)`` (duck-typed;
    the concrete model type is not visible here -- confirm against callers).
    """

    def __init__(self):
        # No state to set up; defined only because RetrainFramework declares
        # __init__ abstract, so omitting it would keep this class abstract.
        pass

    def getAugmentedData(self, data):
        """Return *data* unchanged; the naive baseline performs no augmentation."""
        return data

    def retrainModel(self, model, data):
        """Retrain *model* by calling ``model.fit(data)``.

        This implements the abstract ``RetrainFramework.retrainModel``. The
        original class only defined ``trainModel``, which left it abstract,
        so ``NaiveFramework()`` raised ``TypeError``.
        """
        model.fit(data)

    def trainModel(self, model, data):
        """Backward-compatible alias for :meth:`retrainModel`."""
        self.retrainModel(model, data)

    def evaluateModel(self, model, data):
        """Return the result of ``model.evaluate(data)``."""
        return model.evaluate(data)
    


# In [None]:
from augmentation.randaugment import RandAugment
from arguments import parse_args
from train import get_transform_dict, train
from datasets.datasets import get_datasets
from datasets.loaders import create_loaders
from utils.misc import load_dataset_indices, save_dataset_indices, get_save_path


class FixMatchFrameWork(RetrainFramework):
    """Retrain framework wired to the project's FixMatch data utilities.

    Dataset and loader construction is driven entirely by ``parse_args()``.
    The ``data`` argument of :meth:`getAugmentedData` is not used to build
    them. :meth:`retrainModel` and :meth:`evaluateModel` assume a model
    exposing ``fit(data)`` / ``evaluate(data)``; confirm this against the
    model type actually passed in.
    """

    def __init__(self):
        # No state to set up; defined only because RetrainFramework declares
        # __init__ abstract, so omitting it would keep this class abstract.
        pass

    def getAugmentedData(self, data):
        """Build the FixMatch datasets/loaders, then return *data* unchanged.

        Side effects (via :meth:`_build_loaders`): parses CLI arguments and
        calls ``save_dataset_indices`` with the freshly built splits.

        NOTE(review): the loaders are built and then discarded, so callers
        get back exactly the *data* they passed in. Returning or storing the
        loaders would change this method's contract; confirm the intended
        output before wiring them through.
        """
        self._build_loaders()
        return data

    def _build_loaders(self):
        """Build datasets from CLI args, save their indices, and return the loaders.

        Returns:
            ``((train_loader_labeled, train_loader_unlabeled),
            validation_loader, test_loader)``, unpacked from ``create_loaders``.
        """
        strong_augmentation = RandAugment(n=4, randomized_magnitude=True)
        # NOTE(review): parse_args() presumably reads sys.argv; inside a
        # notebook kernel argv may hold kernel flags -- confirm it parses.
        args = parse_args()
        transform_dict = get_transform_dict(args, strong_augmentation)
        save_path = get_save_path(args)
        # No previously saved split indices are passed in here; presumably
        # get_datasets draws a fresh split in that case -- verify.
        initial_indices = None

        train_sets, validation_set, test_set = get_datasets(
            args.data_dir,
            args.dataset,
            args.num_labeled,
            args.num_validation,
            transform_dict["train"],
            transform_dict["train_unlabeled"],
            transform_dict["test"],
            dataset_indices=initial_indices,
        )

        save_dataset_indices(save_path, train_sets, validation_set)

        (train_loader_labeled, train_loader_unlabeled), validation_loader, test_loader = create_loaders(
            args,
            train_sets["labeled"],
            train_sets["unlabeled"],
            validation_set,
            test_set,
            args.batch_size,
            mu=args.mu,
            total_iters=args.iters_per_epoch,
            num_workers=args.num_workers,
        )
        return (train_loader_labeled, train_loader_unlabeled), validation_loader, test_loader

    def retrainModel(self, model, data):
        """Retrain *model* by calling ``model.fit(data)``.

        This implements the abstract ``RetrainFramework.retrainModel``. The
        original class only defined ``trainModel``, which left it abstract,
        so ``FixMatchFrameWork()`` raised ``TypeError``.

        NOTE(review): the project ``train`` function imported alongside
        ``get_transform_dict`` is not used here; confirm that plain
        ``model.fit`` is the intended FixMatch training path.
        """
        model.fit(data)

    def trainModel(self, model, data):
        """Backward-compatible alias for :meth:`retrainModel`."""
        self.retrainModel(model, data)

    def evaluateModel(self, model, data):
        """Return the result of ``model.evaluate(data)``."""
        return model.evaluate(data)
    
