In [None]:
# Load the labelled `Imm` dataset and derive its list of label names.
# `num_Imm` (the number of label names) is passed to HierarchicalBertClassifier
# in a later cell.
from tools.DataLoader import DataLoader

dataLoader = DataLoader()
df = dataLoader.load_Imm_data()

Imm_names = dataLoader.get_label_names(df)
num_Imm = len(Imm_names)

# Per-label row counts (rendered as the cell output) to eyeball class balance.
df['label'].value_counts()

In [None]:
# Load the unlabelled rows that the active-learning loop draws its query
# windows from (it reads the 'text' column), and preview the first five rows.
unlabeled_data = dataLoader.load_unlabeled_data()
unlabeled_data.head(5)

In [None]:
# Sample the labelled data, split it into train/test frames, and turn those
# frames into model-ready X/y arrays.
# NOTE(review): the meaning of the two knobs below is defined by
# DataPreprocessor.sample_and_split_data. Presumably they are a per-label sample
# size and a multiplier for an "absent" class; confirm there. No random seed is
# set in this cell, so the split may differ between runs.
from tools.DataPreprocessor import DataPreprocessor

Imm_sample_size = 32
Imm_absent_mult = 5

dataPreprocessor = DataPreprocessor()
train_df, test_df = dataPreprocessor.sample_and_split_data(
    df, Imm_sample_size, Imm_absent_mult
)
X_train, X_test, y_train, y_test = dataPreprocessor.prepare_data(train_df, test_df)

In [None]:
# Build the hierarchical BERT classifier from the model files at BERT_VERSION_PATH.
# The path can be overridden with the BERT_VERSION_PATH environment variable so
# the notebook runs on other machines. The original absolute path is kept as the
# default, so existing behaviour is unchanged when the variable is unset.
import os

from tools.HierarchicalBertClassifier import HierarchicalBertClassifier

BERT_VERSION_PATH = os.environ.get(
    'BERT_VERSION_PATH',
    '/home/saveuser/S/projects/rawan2_project/Python Code/bertbase',
)
# num_Imm is the number of label names from the data-loading cell; presumably it
# sizes the classifier's output layer (confirm in HierarchicalBertClassifier).
classifier = HierarchicalBertClassifier(BERT_VERSION_PATH, num_Imm)


In [None]:
from modAL.uncertainty import uncertainty_sampling
from tools.BatchActiveLearner import BatchActiveLearner

# Wrap the classifier in a batch active learner. It is seeded with the labelled
# training split and uses modAL's uncertainty sampling as its query strategy.
learner = BatchActiveLearner(
    estimator=classifier,
    query_strategy=uncertainty_sampling,
    X_training=X_train,
    y_training=y_train,
)

In [None]:
from tools.LabelingTool import LabelingTool
from tqdm.auto import tqdm

labelingTool = LabelingTool()

# --- Active learning loop ---------------------------------------------------
# Each query offers the learner a fresh window of `step` unlabelled texts.
# assign_labels collects labels for the rows the learner selected, and the
# learner is then taught on those rows.
n_queries = 10     # number of batch queries to run each time this cell is executed
start_idx = 20480  # positional row in unlabeled_data where this session starts
step = 1024        # rows per window offered to the learner

# Consecutive windows do not overlap, so no row is offered twice within a run.
# unlabeled_data is deliberately NOT mutated here. Dropping queried rows shifts
# every later row's position, which breaks the window arithmetic and the
# start_idx bookkeeping across sessions, and makes the cell non-idempotent.
with tqdm(total=n_queries, desc="Active Learning") as pbar:
    for _ in range(n_queries):
        X_pool = unlabeled_data.iloc[start_idx:start_idx + step]['text'].values
        if len(X_pool) == 0:
            break  # ran past the end of the unlabelled data

        # query_idx holds positions within X_pool (not within unlabeled_data).
        query_idx = learner.query(X_pool)
        query_instances = [X_pool[idx] for idx in query_idx]
        # NOTE(review): 'citizenship' is passed straight through to LabelingTool;
        # its meaning (label category / column?) is defined there. Confirm.
        y = labelingTool.assign_labels(X_pool,
                                       query_idx,
                                       unlabeled_data,
                                       'citizenship',
                                       Imm_names)

        learner.teach(X=query_instances, y=y)

        start_idx += step
        pbar.update(1)

# Set start_idx to this value next session to continue where this run stopped.
print(f"Next start_idx: {start_idx}")

# Make predictions on the test set
print("Making final predictions on test set...")
predictions = learner.predict(X_test)