
Commit

disable dropout for now; enable tensorboard reporting; retrained on updated ds
simonSlamka committed Dec 5, 2023
1 parent 88529d1 commit ebaced4
Showing 6 changed files with 21 additions and 14 deletions.
2 changes: 0 additions & 2 deletions .github/config.yml

This file was deleted.

1 change: 1 addition & 0 deletions .gitignore
@@ -1,6 +1,7 @@
venv/
wandb/
__pycache__/
logs/
*.jpg
*.png
*.jpeg
26 changes: 17 additions & 9 deletions train.py
@@ -13,8 +13,10 @@
import dlib
import gc
from tqdm import tqdm
from termcolor import colored

# ^ TODO: TRY SWIN
# ! TODO: REFACTOR TO MAKE MORE READABLE AND EASIER TO UNDERSTAND

logging.basicConfig(level=logging.WARNING)
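
The new termcolor import added in this hunk is what drives the green class-balance summary printed further down in this diff. For reference, a minimal sketch of the call pattern; the counts here are made up for illustration:

```python
from termcolor import colored

posCount, negCount = 120, 100  # hypothetical counts, for illustration only
# colored() wraps the string in ANSI escape codes, so the summary prints in green
print(colored(f"Positive faces: {posCount} | Negative faces: {negCount} | Class imbalance: {posCount / negCount}", "green"))
```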

@@ -223,34 +225,38 @@ def check_img_dims(dir): # sanity check to ensure all the imgs are of the dims (
posFiles = [f for f in totalFiles if "pos" in f] # grab positive class girls
negFiles = [f for f in totalFiles if "neg" in f] # grab negative class imgs

print(f"Positive faces: {len(posFiles)} | Negative imgs: {len(negFiles)}") # sanity check
print(f"Positive imgs: {len(posFiles)} | Negative imgs: {len(negFiles)}") # sanity check

userInput = input("Do you want to extract and resize faces? (yes/no): ")
if userInput.lower() == "yes":
processedFiles = []
if os.path.exists("processed.txt"):
with open("processed.txt", "r") as file:
processedFiles = file.read().splitlines()
print(f"Found {len(processedFiles)} processed files")

for file in tqdm(posFiles, desc="Processing positive imgs"):
if file not in processedFiles:
faceDir = os.path.join(dsDir, ".faces", "pos") # positive class face dir
if not os.path.exists(faceDir): # sanity check if path itself exists before saving img
os.makedirs(faceDir) # if not, create it
grab_faces(file, faceDir) # grab faces

if grab_faces(file, faceDir): # grab faces
processedFiles.append(file)
for file in tqdm(negFiles, desc="Processing negative imgs"):
if file not in processedFiles:
faceDir = os.path.join(dsDir, ".faces", "neg") # negative class face dir
destPath = os.path.join(faceDir, os.path.basename(file)) # get dest path
if not os.path.exists(faceDir):
os.makedirs(faceDir)
didWeGrab = grab_faces(file, faceDir) # grab faces
if didWeGrab:
processedFiles.append(file)
# if not didWeGrab: # ! temporarily disabling this check and the central crop to compare results !
# central_crop(file, faceDir) # central crop

totalFaces = [os.path.join(dp, f) for dp, dn, filenames in os.walk(f"{dsDir}/.faces") for f in filenames if f.endswith((".jpg", ".jpeg", ".png"))] # grab all faces for sanity checking
print(f"Total faces: {len(totalFaces)}")
print(colored(f"Positive faces: {len([f for f in totalFaces if 'pos' in f])} | Negative faces: {len([f for f in totalFaces if 'neg' in f])} | Class imbalance: {len([f for f in totalFaces if 'pos' in f]) / len([f for f in totalFaces if 'neg' in f])}", "green"))

if userInput.lower() == "yes":
for face in tqdm(totalFaces, desc="Resizing faces"):
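
The bookkeeping added in this hunk reads a processed.txt manifest up front and only marks a file as processed once grab_faces reports success, so interrupted runs can resume without redoing face extraction. Below is a minimal, self-contained sketch of that pattern; grab_faces is a stand-in for the real detector defined earlier in train.py, the file list and paths are hypothetical, and writing the manifest back out is an assumption (the visible hunk does not show that step):

```python
import os
from tqdm import tqdm

def grab_faces(imgPath: str, outDir: str) -> bool:
    """Stand-in for the real grab_faces() in train.py.
    Returns True when a face was detected and saved to outDir."""
    return True  # placeholder behaviour for the sketch

dsDir = "dataset"                      # hypothetical dataset root
posFiles = ["dataset/pos/a.jpg"]       # hypothetical inputs, for illustration only

# load the manifest of already-processed files so reruns can skip them
processedFiles = []
if os.path.exists("processed.txt"):
    with open("processed.txt", "r") as f:
        processedFiles = f.read().splitlines()

for file in tqdm(posFiles, desc="Processing positive imgs"):
    if file not in processedFiles:
        faceDir = os.path.join(dsDir, ".faces", "pos")
        os.makedirs(faceDir, exist_ok=True)
        if grab_faces(file, faceDir):  # only record the file if a face was actually saved
            processedFiles.append(file)

# assumed step: persist the manifest so the next run can resume where this one stopped
with open("processed.txt", "w") as f:
    f.write("\n".join(processedFiles))
```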
@@ -369,37 +375,39 @@ def compute_metrics(evalPred):

return accuracy.compute(predictions=preds, references=labels)

config = AutoConfig.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=len(labels), id2label=id2label, label2id=label2id, dropout_rate=0.5) # load config
config = AutoConfig.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=len(labels), id2label=id2label, label2id=label2id) #, dropout_rate=0.5) # load config
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", config=config) # load base model

trainingArgs = TrainingArguments(
output_dir="./out",
logging_dir="./logs",
remove_unused_columns=False,
evaluation_strategy="steps",
save_strategy="steps",
learning_rate=5e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
gradient_accumulation_steps=4,
weight_decay=0.015,
weight_decay=0.01,
num_train_epochs=15,
warmup_ratio=0.15,
lr_scheduler_type="linear", # "polynomial", "constant_with_warmup", "constant", "cosine", "polynomial"
lr_scheduler_type="cosine", # "polynomial", "constant_with_warmup", "constant", "linear", "polynomial"
seed=69,
save_steps=15,
eval_steps=15,
save_safetensors=True,
save_total_limit=5,
logging_steps=15,
logging_steps=1,
load_best_model_at_end=True,
metric_for_best_model="accuracy",
report_to=["wandb", "tensorboard"],
push_to_hub=True,
hub_model_id="attraction-classifier"
)

earlyStop = EarlyStoppingCallback(
early_stopping_patience=5,
early_stopping_threshold=0.01,
early_stopping_patience=10,
#early_stopping_threshold=0.01 # disabled for now
)

trainer = Trainer(
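
Taken together, the train.py changes amount to: the custom dropout_rate override on the ViT config is commented out, TensorBoard is added as a reporting backend alongside wandb (with event files under ./logs, matching the logs/ entry added to .gitignore), the LR schedule moves from linear to cosine, weight decay drops to 0.01, per-step logging is enabled, and early stopping now waits 10 evaluations with the threshold left at its default. A minimal sketch of how those pieces assemble, under stated assumptions: the label names, the compute_metrics body beyond the visible return line, and the dataset stand-ins are mine, not taken from the commit:

```python
import numpy as np
import evaluate
from transformers import (AutoConfig, AutoModelForImageClassification,
                          TrainingArguments, Trainer, EarlyStoppingCallback)

labels = ["neg", "pos"]                                    # assumed label names
id2label = {i: l for i, l in enumerate(labels)}
label2id = {l: i for i, l in enumerate(labels)}

accuracy = evaluate.load("accuracy")

def compute_metrics(evalPred):
    # reconstruction of the helper; only the return line is visible in the diff
    logits, refs = evalPred
    preds = np.argmax(logits, axis=-1)
    return accuracy.compute(predictions=preds, references=refs)

config = AutoConfig.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=len(labels), id2label=id2label, label2id=label2id,
    # dropout_rate=0.5,  # custom dropout override disabled in this commit
)
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k", config=config)

trainingArgs = TrainingArguments(
    output_dir="./out",
    logging_dir="./logs",                    # TensorBoard event files land here
    report_to=["wandb", "tensorboard"],      # tensorboard reporting enabled alongside wandb
    lr_scheduler_type="cosine",              # switched from linear
    weight_decay=0.01,
    logging_steps=1,                         # log every optimizer step
    evaluation_strategy="steps",
    eval_steps=15,
    save_strategy="steps",
    save_steps=15,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    seed=69,
)

trainer = Trainer(
    model=model,
    args=trainingArgs,
    train_dataset=None,                      # stand-in; the real split is built earlier in train.py
    eval_dataset=None,                       # stand-in; the real split is built earlier in train.py
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
)
```

TensorBoard can then be pointed at the same directory (tensorboard --logdir ./logs) to browse the per-step curves, which is presumably why logs/ now appears in .gitignore.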
2 changes: 1 addition & 1 deletion wandb/debug-internal.log
2 changes: 1 addition & 1 deletion wandb/debug.log
2 changes: 1 addition & 1 deletion wandb/latest-run
